diff --git a/api/graphql/schema.py b/api/graphql/schema.py index d0ac1619f..1b7213642 100644 --- a/api/graphql/schema.py +++ b/api/graphql/schema.py @@ -18,8 +18,10 @@ from api.graphql.filters import GraphQLFilter, GraphQLMetaFilter from api.graphql.loaders import LoaderKeys, get_context from db.python import enum_tables +from db.python.gcp_connect import BqConnection from db.python.layers import AnalysisLayer, SampleLayer, SequencingGroupLayer from db.python.layers.assay import AssayLayer +from db.python.layers.billing import BillingLayer from db.python.layers.family import FamilyLayer from db.python.tables.analysis import AnalysisFilter from db.python.tables.assay import AssayFilter @@ -32,6 +34,9 @@ AnalysisInternal, AssayInternal, AuditLogInternal, + BillingColumn, + BillingInternal, + BillingTotalCostQueryModel, FamilyInternal, ParticipantInternal, Project, @@ -497,7 +502,6 @@ class GraphQLSequencingGroup: @staticmethod def from_internal(internal: SequencingGroupInternal) -> 'GraphQLSequencingGroup': - # print(internal) return GraphQLSequencingGroup( id=sequencing_group_id_format(internal.id), type=internal.type, @@ -593,6 +597,33 @@ async def sample(self, info: Info, root: 'GraphQLAssay') -> GraphQLSample: return GraphQLSample.from_internal(sample) +@strawberry.type +class GraphQLBilling: + """GraphQL Billing""" + + id: str | None + ar_guid: str | None + gcp_project: str | None + topic: str | None + batch_id: str | None + cost_category: str | None + day: datetime.date | None + cost: float | None + + @staticmethod + def from_internal(internal: BillingInternal) -> 'GraphQLBilling': + return GraphQLBilling( + id=internal.id, + ar_guid=internal.ar_guid, + gcp_project=internal.gcp_project, + topic=internal.topic, + batch_id=internal.batch_id, + cost_category=internal.cost_category, + day=internal.day, + cost=internal.cost, + ) + + @strawberry.type class Query: """GraphQL Queries""" @@ -731,6 +762,81 @@ async def my_projects(self, info: Info) -> 
list[GraphQLProject]: ) return [GraphQLProject.from_internal(p) for p in projects] + @strawberry.field + async def billing( + self, + info: Info, + batch_id: str | None = None, + ar_guid: str | None = None, + topic: str | None = None, + gcp_project: str | None = None, + day: GraphQLFilter[datetime.datetime] | None = None, + cost: GraphQLFilter[float] | None = None, + ) -> list[GraphQLBilling]: + """ + This is the first raw implementation of Billing inside GraphQL + """ + # TODO check billing is enabled e.g.: + # if not is_billing_enabled(): + # raise ValueError('Billing is not enabled') + + # TODO is there a better way to get the BQ connection? + connection = info.context['connection'] + bg_connection = BqConnection(connection.author) + slayer = BillingLayer(bg_connection) + + if ar_guid: + res = await slayer.get_cost_by_ar_guid(ar_guid) + if res: + # only show the costs + res = res.costs + + elif batch_id: + res = await slayer.get_cost_by_batch_id(batch_id) + if res: + # only show the costs + res = res.costs + + else: + # TODO construct fields from request.body (selected attributes) + # For time being, just use these fields + fields = [ + BillingColumn.DAY, + BillingColumn.COST, + BillingColumn.COST_CATEGORY, + ] + + filters = {} + if topic: + filters['topic'] = topic + fields.append(BillingColumn.TOPIC) + if gcp_project: + filters['gcp_project'] = gcp_project + fields.append(BillingColumn.GCP_PROJECT) + + if day: + all_days_vals = day.all_values() + start_date = min(all_days_vals).strftime('%Y-%m-%d') + end_date = max(all_days_vals).strftime('%Y-%m-%d') + else: + # TODO we need to limit to small time periods to avoid huge charges + # If day is not selected use only current day records + start_date = datetime.datetime.now().strftime('%Y-%m-%d') + end_date = start_date + + query = BillingTotalCostQueryModel( + fields=fields, + start_date=start_date, + end_date=end_date, + filters=filters, + ) + res = await slayer.get_total_cost(query) + + return [ + 
GraphQLBilling.from_internal(BillingInternal.from_db(**dict(p))) + for p in res + ] + schema = strawberry.Schema( query=Query, mutation=None, extensions=[QueryDepthLimiter(max_depth=10)] diff --git a/api/routes/billing.py b/api/routes/billing.py index f903717d6..be1fb46ff 100644 --- a/api/routes/billing.py +++ b/api/routes/billing.py @@ -7,7 +7,7 @@ from api.settings import BILLING_CACHE_RESPONSE_TTL, BQ_AGGREG_VIEW from api.utils.db import BqConnection, get_author from db.python.layers.billing import BillingLayer -from models.models.billing import ( +from models.models import ( BillingColumn, BillingCostBudgetRecord, BillingHailBatchCostRecord, diff --git a/db/python/layers/billing.py b/db/python/layers/billing.py index 89f71e155..8f991c728 100644 --- a/db/python/layers/billing.py +++ b/db/python/layers/billing.py @@ -7,13 +7,11 @@ from models.models import ( BillingColumn, BillingCostBudgetRecord, - BillingTotalCostQueryModel, -) -from models.models.billing import ( BillingHailBatchCostRecord, BillingSource, BillingTimeColumn, BillingTimePeriods, + BillingTotalCostQueryModel, ) diff --git a/db/python/tables/bq/billing_ar_batch.py b/db/python/tables/bq/billing_ar_batch.py index be2d3093e..d9326b6b3 100644 --- a/db/python/tables/bq/billing_ar_batch.py +++ b/db/python/tables/bq/billing_ar_batch.py @@ -30,7 +30,7 @@ async def get_batches_by_ar_guid( WHERE ar_guid = @ar_guid AND batch_id IS NOT NULL GROUP BY batch_id - ORDER BY 1; + ORDER BY batch_id; """ query_parameters = [ diff --git a/db/python/tables/bq/billing_base.py b/db/python/tables/bq/billing_base.py index 8bb1158b4..12e56f02b 100644 --- a/db/python/tables/bq/billing_base.py +++ b/db/python/tables/bq/billing_base.py @@ -9,6 +9,9 @@ from api.settings import BQ_BUDGET_VIEW, BQ_DAYS_BACK_OPTIMAL from api.utils.dates import get_invoice_month_range, reformat_datetime from db.python.gcp_connect import BqDbBase +from db.python.tables.bq.billing_filter import BillingFilter +from 
db.python.tables.bq.function_bq_filter import FunctionBQFilter +from db.python.tables.bq.generic_bq_filter import GenericBQFilter from models.models import ( BillingColumn, BillingCostBudgetRecord, @@ -22,7 +25,9 @@ # Day Time details used in grouping and parsing formulas -TimeGroupingDetails = namedtuple('TimeGroupingDetails', ['field', 'formula']) +TimeGroupingDetails = namedtuple( + 'TimeGroupingDetails', ['field', 'formula', 'separator'] +) def abbrev_cost_category(cost_category: str) -> str: @@ -35,64 +40,37 @@ def prepare_time_periods( ) -> TimeGroupingDetails: """Prepare Time periods grouping and parsing formulas""" time_column = query.time_column or 'day' - result = TimeGroupingDetails('', '') + result = TimeGroupingDetails('', '', '') # Based on specified time period, add the corresponding column if query.time_periods == BillingTimePeriods.DAY: result = TimeGroupingDetails( - field=f'FORMAT_DATE("%Y-%m-%d", {time_column}) as day, ', - formula='PARSE_DATE("%Y-%m-%d", day) as day, ', + field=f'FORMAT_DATE("%Y-%m-%d", {time_column}) as day', + formula='PARSE_DATE("%Y-%m-%d", day) as day', + separator=',', ) elif query.time_periods == BillingTimePeriods.WEEK: result = TimeGroupingDetails( - field=f'FORMAT_DATE("%Y%W", {time_column}) as day, ', - formula='PARSE_DATE("%Y%W", day) as day, ', + field=f'FORMAT_DATE("%Y%W", {time_column}) as day', + formula='PARSE_DATE("%Y%W", day) as day', + separator=',', ) elif query.time_periods == BillingTimePeriods.MONTH: result = TimeGroupingDetails( - field=f'FORMAT_DATE("%Y%m", {time_column}) as day, ', - formula='PARSE_DATE("%Y%m", day) as day, ', + field=f'FORMAT_DATE("%Y%m", {time_column}) as day', + formula='PARSE_DATE("%Y%m", day) as day', + separator=',', ) elif query.time_periods == BillingTimePeriods.INVOICE_MONTH: result = TimeGroupingDetails( - field='invoice_month as day, ', formula='PARSE_DATE("%Y%m", day) as day, ' + field='invoice_month as day', + formula='PARSE_DATE("%Y%m", day) as day', + separator=',', ) 
return result -def construct_filter( - name: str, value: Any, is_label: bool = False -) -> tuple[str, bigquery.ScalarQueryParameter | bigquery.ArrayQueryParameter]: - """Based on Filter value, construct filter string and query parameter - - Args: - name (str): Filter name - value (Any): Filter value - is_label (bool, optional): Is filter a label?. Defaults to False. - - Returns: - tuple[str, bigquery.ScalarQueryParameter | bigquery.ArrayQueryParameter] - """ - compare = '=' - b1, b2 = '', '' - param_type = bigquery.ScalarQueryParameter - key = name.replace('-', '_') - - if isinstance(value, list): - compare = 'IN' - b1, b2 = 'UNNEST(', ')' - param_type = bigquery.ArrayQueryParameter - - if is_label: - name = f'getLabelValue(labels, "{name}")' - - return ( - f'{name} {compare} {b1}@{key}{b2}', - param_type(key, 'STRING', value), - ) - - class BillingBaseTable(BqDbBase): """Billing Base Table This is abstract class, it should not be instantiated @@ -124,6 +102,26 @@ def _execute_query( # otherwise return as BQ iterator return self._connection.connection.query(query, job_config=job_config) + def _query_to_partitioned_filter( + self, query: BillingTotalCostQueryModel + ) -> BillingFilter: + """ + By default views are partitioned by 'day', + if different then overwrite in the subclass + """ + billing_filter = query.to_filter() + + # initial partition filter + billing_filter.day = GenericBQFilter[datetime]( + gte=datetime.strptime(query.start_date, '%Y-%m-%d') + if query.start_date + else None, + lte=datetime.strptime(query.end_date, '%Y-%m-%d') + if query.end_date + else None, + ) + return billing_filter + def _filter_to_optimise_query(self) -> str: """Filter string to optimise BQ query""" return 'day >= TIMESTAMP(@start_day) AND day <= TIMESTAMP(@last_day)' @@ -132,67 +130,6 @@ def _last_loaded_day_filter(self) -> str: """Last Loaded day filter string""" return 'day = TIMESTAMP(@last_loaded_day)' - def _prepare_time_filters(self, query: BillingTotalCostQueryModel): - 
"""Prepare time filters""" - time_column = query.time_column or 'day' - time_filters = [] - query_parameters = [] - - if query.start_date: - time_filters.append(f'{time_column} >= TIMESTAMP(@start_date)') - query_parameters.extend( - [ - bigquery.ScalarQueryParameter( - 'start_date', 'STRING', query.start_date - ), - ] - ) - if query.end_date: - time_filters.append(f'{time_column} <= TIMESTAMP(@end_date)') - query_parameters.extend( - [ - bigquery.ScalarQueryParameter('end_date', 'STRING', query.end_date), - ] - ) - - return time_filters, query_parameters - - def _prepare_filter_str(self, query: BillingTotalCostQueryModel): - """Prepare filter string""" - and_filters, query_parameters = self._prepare_time_filters(query) - - # No additional filters - filters = [] - if not query.filters: - filter_str = 'WHERE ' + ' AND '.join(and_filters) if and_filters else '' - return filter_str, query_parameters - - # Add each of the filters in the query - for filter_key, filter_value in query.filters.items(): - col_name = str(filter_key.value) - - if not isinstance(filter_value, dict): - filter_, query_param = construct_filter(col_name, filter_value) - filters.append(filter_) - query_parameters.append(query_param) - else: - for label_key, label_value in filter_value.items(): - filter_, query_param = construct_filter( - label_key, label_value, True - ) - filters.append(filter_) - query_parameters.append(query_param) - - if query.filters_op == 'OR': - if filters: - and_filters.append('(' + ' OR '.join(filters) + ')') - else: - # if not specified, default to AND - and_filters.extend(filters) - - filter_str = 'WHERE ' + ' AND '.join(and_filters) if and_filters else '' - return filter_str, query_parameters - def _convert_output(self, query_job_result): """Convert query result to json""" if not query_job_result or query_job_result.result().total_rows == 0: @@ -483,70 +420,138 @@ async def _append_running_cost_records( return results - async def get_total_cost( - self, - query: 
BillingTotalCostQueryModel, - ) -> list[dict] | None: - """ - Get Total cost of selected fields for requested time interval from BQ view - """ - if not query.start_date or not query.end_date or not query.fields: - raise ValueError('Date and Fields are required') + def _prepare_order_by_string( + self, order_by: dict[BillingColumn, bool] | None + ) -> str: + """Prepare order by string""" + if not order_by: + return '' + + order_by_cols = [] + for order_field, reverse in order_by.items(): + col_name = str(order_field.value) + col_order = 'DESC' if reverse else 'ASC' + order_by_cols.append(f'{col_name} {col_order}') + + return f'ORDER BY {",".join(order_by_cols)}' if order_by_cols else '' + + def _prepare_aggregation( + self, query: BillingTotalCostQueryModel + ) -> tuple[str, str]: + """Prepare both fields for aggregation and group by string""" + # Get columns to group by + + # if group by is populated, then we need to group by day as well + grp_columns = ['day'] if query.group_by else [] - # Get columns to group by and check view to use - grp_columns = [] for field in query.fields: col_name = str(field.value) if not BillingColumn.can_group_by(field): # if the field cannot be grouped by, skip it continue + # append to potential columns to group by grp_columns.append(col_name) - grp_selected = ','.join(grp_columns) fields_selected = ','.join( (field.value for field in query.fields if field != BillingColumn.COST) ) + grp_selected = ','.join(grp_columns) + group_by = f'GROUP BY {grp_selected}' if query.group_by else '' + + return fields_selected, group_by + + def _prepare_labels_function(self, query: BillingTotalCostQueryModel): + if not query.filters: + return None + + if BillingColumn.LABELS in query.filters and isinstance( + query.filters[BillingColumn.LABELS], dict + ): + # prepare labels as function filters, parameterized both sides + func_filter = FunctionBQFilter( + name='getLabelValue', + implementation=""" + CREATE TEMP FUNCTION getLabelValue( + labels ARRAY>, 
label STRING + ) AS ( + (SELECT value FROM UNNEST(labels) WHERE key = label LIMIT 1) + ); + """, + ) + func_filter.to_sql( + BillingColumn.LABELS, + query.filters[BillingColumn.LABELS], + query.filters_op, + ) + return func_filter + + # otherwise + return None + + async def get_total_cost( + self, + query: BillingTotalCostQueryModel, + ) -> list[dict] | None: + """ + Get Total cost of selected fields for requested time interval from BQ views + """ + if not query.start_date or not query.end_date or not query.fields: + raise ValueError('Date and Fields are required') + + # Get columns to select and to group by + fields_selected, group_by = self._prepare_aggregation(query) + + # construct order by + order_by_str = self._prepare_order_by_string(query.order_by) + # prepare grouping by time periods - time_group = TimeGroupingDetails('', '') + time_group = TimeGroupingDetails('', '', '') if query.time_periods or query.time_column: - # remove existing day column, if added to fields - # this is to prevent duplicating various time periods in one query - # if BillingColumn.DAY in query.fields: - # columns.remove(BillingColumn.DAY) time_group = prepare_time_periods(query) - filter_str, query_parameters = self._prepare_filter_str(query) + # overrides time specific fields with relevant time column name + query_filter = self._query_to_partitioned_filter(query) - # construct order by - order_by_cols = [] - if query.order_by: - for order_field, reverse in query.order_by.items(): - col_name = str(order_field.value) - col_order = 'DESC' if reverse else 'ASC' - order_by_cols.append(f'{col_name} {col_order}') + # prepare where string and SQL parameters + where_str, sql_parameters = query_filter.to_sql() + + # extract only BQ Query parameter, keys are not used in BQ SQL + # have to declare empty list first as linting is not happy + query_parameters: list[ + bigquery.ScalarQueryParameter | bigquery.ArrayQueryParameter + ] = [] + query_parameters.extend(sql_parameters.values()) + + # 
prepare labels as function filters if present + func_filter = self._prepare_labels_function(query) + if func_filter: + # extend where_str and query_parameters + query_parameters.extend(func_filter.func_sql_parameters) + + # now join Prepared Where with Labels Function Where + where_str = ' AND '.join([where_str, func_filter.func_where]) - order_by_str = f'ORDER BY {",".join(order_by_cols)}' if order_by_cols else '' + # if group by is populated, then we need SUM the cost, otherwise raw cost + cost_column = 'SUM(cost) as cost' if query.group_by else 'cost' - group_by = f'GROUP BY day, {grp_selected}' if query.group_by else '' - cost = 'SUM(cost) as cost' if query.group_by else 'cost' + if where_str: + # Where is not empty, prepend with WHERE + where_str = f'WHERE {where_str}' _query = f""" - CREATE TEMP FUNCTION getLabelValue( - labels ARRAY>, label STRING - ) AS ( - (SELECT value FROM UNNEST(labels) WHERE key = label LIMIT 1) - ); + {func_filter.fun_implementation if func_filter else ''} WITH t AS ( - SELECT {time_group.field}{fields_selected}, {cost} + SELECT {time_group.field}{time_group.separator} {fields_selected}, + {cost_column} FROM `{self.get_table_name()}` - {filter_str} + {where_str} {group_by} {order_by_str} ) - SELECT {time_group.formula}{fields_selected}, cost FROM t + SELECT {time_group.formula}{time_group.separator} {fields_selected}, cost FROM t """ # append min cost condition diff --git a/db/python/tables/bq/billing_filter.py b/db/python/tables/bq/billing_filter.py new file mode 100644 index 000000000..1c333c767 --- /dev/null +++ b/db/python/tables/bq/billing_filter.py @@ -0,0 +1,49 @@ +import dataclasses +import datetime + +from db.python.tables.bq.generic_bq_filter import GenericBQFilter +from db.python.tables.bq.generic_bq_filter_model import GenericBQFilterModel + + +@dataclasses.dataclass +class BillingFilter(GenericBQFilterModel): + """ + Filter for billing, contains all possible attributes to filter on + """ + + # partition specific filters: 
+ + # most billing views are parttioned by day + day: GenericBQFilter[datetime.datetime] = None + + # gpc table has different partition field: part_time + part_time: GenericBQFilter[datetime.datetime] = None + + # aggregate has different partition field: usage_end_time + usage_end_time: GenericBQFilter[datetime.datetime] = None + + # common filters: + invoice_month: GenericBQFilter[str] = None + + # min cost e.g. 0.01, if not set, will show all + cost: GenericBQFilter[float] = None + + ar_guid: GenericBQFilter[str] = None + gcp_project: GenericBQFilter[str] = None + topic: GenericBQFilter[str] = None + batch_id: GenericBQFilter[str] = None + cost_category: GenericBQFilter[str] = None + sku: GenericBQFilter[str] = None + dataset: GenericBQFilter[str] = None + sequencing_type: GenericBQFilter[str] = None + stage: GenericBQFilter[str] = None + sequencing_group: GenericBQFilter[str] = None + compute_category: GenericBQFilter[str] = None + cromwell_sub_workflow_name: GenericBQFilter[str] = None + cromwell_workflow_id: GenericBQFilter[str] = None + goog_pipelines_worker: GenericBQFilter[str] = None + wdl_task_name: GenericBQFilter[str] = None + namespace: GenericBQFilter[str] = None + + def __hash__(self): + return super().__hash__() diff --git a/db/python/tables/bq/billing_gcp_daily.py b/db/python/tables/bq/billing_gcp_daily.py index 1924b01e6..1588ab65f 100644 --- a/db/python/tables/bq/billing_gcp_daily.py +++ b/db/python/tables/bq/billing_gcp_daily.py @@ -1,7 +1,11 @@ +from datetime import datetime, timedelta + from google.cloud import bigquery from api.settings import BQ_DAYS_BACK_OPTIMAL, BQ_GCP_BILLING_VIEW from db.python.tables.bq.billing_base import BillingBaseTable +from db.python.tables.bq.billing_filter import BillingFilter +from db.python.tables.bq.generic_bq_filter import GenericBQFilter from models.models import BillingTotalCostQueryModel @@ -14,50 +18,28 @@ def get_table_name(self): """Get table name""" return self.table_name - def 
_filter_to_optimise_query(self) -> str: - """Filter string to optimise BQ query - override base class method as gcp table has different partition field + def _query_to_partitioned_filter( + self, query: BillingTotalCostQueryModel + ) -> BillingFilter: """ - # add extra filter to limit materialized view partition - # Raw BQ billing table is partitioned by part_time (when data are loaded) - # and not by end of usage time (day) - # There is a delay up to 4-5 days between part_time and day - # 7 days is added to be sure to get all data - return ( - 'part_time >= TIMESTAMP(@start_day)' - 'AND part_time <= TIMESTAMP_ADD(TIMESTAMP(@last_day), INTERVAL 7 DAY)' - ) - - def _last_loaded_day_filter(self) -> str: - """Filter string to optimise BQ query - override base class method as gcp table has different partition field + add extra filter to limit materialized view partition + Raw BQ billing table is partitioned by part_time (when data are loaded) + and not by end of usage time (day) + There is a delay up to 4-5 days between part_time and day + 7 days is added to be sure to get all data """ - # add extra filter to limit materialized view partition - # Raw BQ billing table is partitioned by part_time (when data are loaded) - # and not by end of usage time (day) - # There is a delay up to 4-5 days between part_time and day - # 7 days is added to be sure to get all data - return ( - 'day = TIMESTAMP(@last_loaded_day)' - 'AND part_time >= TIMESTAMP(@last_loaded_day)' - 'AND part_time <= TIMESTAMP_ADD(TIMESTAMP(@last_loaded_day),INTERVAL 7 DAY)' - ) - - def _prepare_time_filters(self, query: BillingTotalCostQueryModel): - """Prepare time filters, append to time_filters list""" - time_filters, query_parameters = super()._prepare_time_filters(query) - - # BQ_GCP_BILLING_VIEW view is partitioned by different field - # BQ has limitation, materialized view can only by partition by base table - # partition or its subset, in our case _PARTITIONTIME - # (part_time field in the view) - # 
We are querying by day, - # which can be up to a week behind regarding _PARTITIONTIME - time_filters.append('part_time >= TIMESTAMP(@start_date)') - time_filters.append( - 'part_time <= TIMESTAMP_ADD(TIMESTAMP(@end_date), INTERVAL 7 DAY)' + billing_filter = query.to_filter() + + # initial partition filter + billing_filter.part_time = GenericBQFilter[datetime]( + gte=datetime.strptime(query.start_date, '%Y-%m-%d') + if query.start_date + else None, + lte=(datetime.strptime(query.end_date, '%Y-%m-%d') + timedelta(days=7)) + if query.end_date + else None, ) - return time_filters, query_parameters + return billing_filter async def _last_loaded_day(self): """Get the most recent fully loaded day in db diff --git a/db/python/tables/bq/billing_raw.py b/db/python/tables/bq/billing_raw.py index 6a6c7b83e..a82fa4eec 100644 --- a/db/python/tables/bq/billing_raw.py +++ b/db/python/tables/bq/billing_raw.py @@ -1,5 +1,10 @@ +from datetime import datetime + from api.settings import BQ_AGGREG_RAW from db.python.tables.bq.billing_base import BillingBaseTable +from db.python.tables.bq.billing_filter import BillingFilter +from db.python.tables.bq.generic_bq_filter import GenericBQFilter +from models.models import BillingTotalCostQueryModel class BillingRawTable(BillingBaseTable): @@ -10,3 +15,22 @@ class BillingRawTable(BillingBaseTable): def get_table_name(self): """Get table name""" return self.table_name + + def _query_to_partitioned_filter( + self, query: BillingTotalCostQueryModel + ) -> BillingFilter: + """ + Raw BQ billing table is partitioned by usage_end_time + """ + billing_filter = query.to_filter() + + # initial partition filter + billing_filter.usage_end_time = GenericBQFilter[datetime]( + gte=datetime.strptime(query.start_date, '%Y-%m-%d') + if query.start_date + else None, + lte=datetime.strptime(query.end_date, '%Y-%m-%d') + if query.end_date + else None, + ) + return billing_filter diff --git a/db/python/tables/bq/function_bq_filter.py 
b/db/python/tables/bq/function_bq_filter.py new file mode 100644 index 000000000..07e89e1eb --- /dev/null +++ b/db/python/tables/bq/function_bq_filter.py @@ -0,0 +1,109 @@ +from datetime import datetime +from enum import Enum +from typing import Any + +from google.cloud import bigquery + +from models.models import BillingColumn + + +class FunctionBQFilter: + """ + Function BigQuery filter where left site is a function call + In such case we need to parameterised values on both side of SQL + E.g. + + SELECT ... + FROM ... + WHERE getLabelValue(labels, 'batch_id') = '1234' + + In this case we have 2 string values which need to be parameterised + """ + + func_where = '' + func_sql_parameters: list[ + bigquery.ScalarQueryParameter | bigquery.ArrayQueryParameter + ] = [] + + def __init__(self, name: str, implementation: str): + self.func_name = name + self.fun_implementation = implementation + # param_id is a counter for parameterised values + self._param_id = 0 + + def to_sql( + self, + column_name: BillingColumn, + func_params: str | list[Any] | dict[Any, Any], + func_operator: str = None, + ) -> tuple[str, list[bigquery.ScalarQueryParameter | bigquery.ArrayQueryParameter]]: + """ + creates the left side of where : FUN(column_name, @params) + each of func_params convert to BQ parameter + combined multiple calls with provided operator, + if func_operator is None then AND is assumed by default + """ + values = [] + conditionals = [] + + if not isinstance(func_params, dict): + # Ignore func_params which are not dictionary for the time being + return '', [] + + for param_key, param_value in func_params.items(): + # parameterised both param_key and param_value + # e.g. 
this is raw SQL example: + # getLabelValue(labels, {param_key}) = {param_value} + self._param_id += 1 + key = f'param{self._param_id}' + val = f'value{self._param_id}' + # add param_key as parameterised BQ value + values.append(FunctionBQFilter._sql_value_prep(key, param_key)) + + # add param_value as parameterised BQ value + values.append(FunctionBQFilter._sql_value_prep(val, param_value)) + + # format as FUN(column_name, @param) = @value + conditionals.append( + ( + f'{self.func_name}({column_name.value},@{key}) = ' + f'{FunctionBQFilter._sql_cond_prep(val, param_value)}' + ) + ) + + if func_operator and func_operator == 'OR': + condition = ' OR '.join(conditionals) + else: + condition = ' AND '.join(conditionals) + + # set the class variables for later use + self.func_where = f'({condition})' + self.func_sql_parameters = values + return self.func_where, self.func_sql_parameters + + @staticmethod + def _sql_cond_prep(key: str, value: Any) -> str: + """ + By default '{key}' is used, + but for datetime it has to be wrapped in TIMESTAMP({key}) + """ + if isinstance(value, datetime): + return f'TIMESTAMP(@{key})' + + # otherwise as default + return f'@{key}' + + @staticmethod + def _sql_value_prep(key: str, value: Any) -> bigquery.ScalarQueryParameter: + """ """ + if isinstance(value, Enum): + return bigquery.ScalarQueryParameter(key, 'STRING', value.value) + if isinstance(value, int): + return bigquery.ScalarQueryParameter(key, 'INT64', value) + if isinstance(value, float): + return bigquery.ScalarQueryParameter(key, 'FLOAT64', value) + if isinstance(value, datetime): + return bigquery.ScalarQueryParameter(key, 'STRING', value) + + # otherwise as string parameter + return bigquery.ScalarQueryParameter(key, 'STRING', value) diff --git a/db/python/tables/bq/generic_bq_filter.py b/db/python/tables/bq/generic_bq_filter.py new file mode 100644 index 000000000..7e75b7d00 --- /dev/null +++ b/db/python/tables/bq/generic_bq_filter.py @@ -0,0 +1,101 @@ +from datetime import 
datetime +from enum import Enum +from typing import Any + +from google.cloud import bigquery + +from db.python.utils import GenericFilter, T + + +class GenericBQFilter(GenericFilter[T]): + """ + Generic BigQuery filter is BQ specific filter class, based on GenericFilter + """ + + def to_sql( + self, column: str, column_name: str = None + ) -> tuple[str, dict[str, T | list[T] | Any | list[Any]]]: + """ + Convert to SQL, and avoid SQL injection + + """ + conditionals = [] + values: dict[str, T | list[T] | Any | list[Any]] = {} + _column_name = column_name or column + + if not isinstance(column, str): + raise ValueError(f'Column {_column_name!r} must be a string') + if self.eq is not None: + k = self.generate_field_name(_column_name + '_eq') + conditionals.append(f'{column} = {self._sql_cond_prep(k, self.eq)}') + values[k] = self._sql_value_prep(k, self.eq) + if self.in_ is not None: + if not isinstance(self.in_, list): + raise ValueError('IN filter must be a list') + if len(self.in_) == 1: + k = self.generate_field_name(_column_name + '_in_eq') + conditionals.append(f'{column} = {self._sql_cond_prep(k, self.in_[0])}') + values[k] = self._sql_value_prep(k, self.in_[0]) + else: + k = self.generate_field_name(_column_name + '_in') + conditionals.append(f'{column} IN ({self._sql_cond_prep(k, self.in_)})') + values[k] = self._sql_value_prep(k, self.in_) + if self.nin is not None: + if not isinstance(self.nin, list): + raise ValueError('NIN filter must be a list') + k = self.generate_field_name(column + '_nin') + conditionals.append(f'{column} NOT IN ({self._sql_cond_prep(k, self.nin)})') + values[k] = self._sql_value_prep(k, self.nin) + if self.gt is not None: + k = self.generate_field_name(column + '_gt') + conditionals.append(f'{column} > {self._sql_cond_prep(k, self.gt)}') + values[k] = self._sql_value_prep(k, self.gt) + if self.gte is not None: + k = self.generate_field_name(column + '_gte') + conditionals.append(f'{column} >= {self._sql_cond_prep(k, self.gte)}') + 
values[k] = self._sql_value_prep(k, self.gte) + if self.lt is not None: + k = self.generate_field_name(column + '_lt') + conditionals.append(f'{column} < {self._sql_cond_prep(k, self.lt)}') + values[k] = self._sql_value_prep(k, self.lt) + if self.lte is not None: + k = self.generate_field_name(column + '_lte') + conditionals.append(f'{column} <= {self._sql_cond_prep(k, self.lte)}') + values[k] = self._sql_value_prep(k, self.lte) + + return ' AND '.join(conditionals), values + + @staticmethod + def _sql_cond_prep(key, value) -> str: + """ + By default '@{key}' is used, + but for datetime it has to be wrapped in TIMESTAMP(@{k}) + """ + if isinstance(value, datetime): + return f'TIMESTAMP(@{key})' + + # otherwise as default + return f'@{key}' + + @staticmethod + def _sql_value_prep(key, value): + """ + Overrides the default _sql_value_prep to handle BQ parameters + """ + if isinstance(value, list): + return bigquery.ArrayQueryParameter( + key, 'STRING', ','.join([str(v) for v in value]) + ) + if isinstance(value, Enum): + return bigquery.ScalarQueryParameter(key, 'STRING', value.value) + if isinstance(value, int): + return bigquery.ScalarQueryParameter(key, 'INT64', value) + if isinstance(value, float): + return bigquery.ScalarQueryParameter(key, 'FLOAT64', value) + if isinstance(value, datetime): + return bigquery.ScalarQueryParameter( + key, 'STRING', value.strftime('%Y-%m-%d %H:%M:%S') + ) + + # otherwise as string parameter + return bigquery.ScalarQueryParameter(key, 'STRING', value) diff --git a/db/python/tables/bq/generic_bq_filter_model.py b/db/python/tables/bq/generic_bq_filter_model.py new file mode 100644 index 000000000..3de5051af --- /dev/null +++ b/db/python/tables/bq/generic_bq_filter_model.py @@ -0,0 +1,109 @@ +import dataclasses +from typing import Any + +from db.python.tables.bq.generic_bq_filter import GenericBQFilter +from db.python.utils import GenericFilterModel + + +def prepare_bq_query_from_dict_field( + filter_, field_name, column_name +) -> 
def prepare_bq_query_from_dict_field(
    filter_, field_name, column_name
) -> tuple[list[str], dict[str, Any]]:
    """
    Prepare a BQ SQL query from a dict field, which is a dict of
    GenericBQFilters. Usually this is a JSON field in the database that we
    want to query on.

    Returns a tuple of (conditional fragments, query parameter values).
    """
    conditionals: list[str] = []
    values: dict[str, Any] = {}
    for key, value in filter_.items():
        if not isinstance(value, GenericBQFilter):
            # the check is against GenericBQFilter, so say so in the message
            raise ValueError(f'Filter {field_name} must be a GenericBQFilter')
        if '"' in key:
            raise ValueError('Meta key contains " character, which is not allowed')
        fconditionals, fvalues = value.to_sql(
            f"JSON_EXTRACT({column_name}, '$.{key}')",
            column_name=f'{column_name}_{key}',
        )
        conditionals.append(fconditionals)
        values.update(fvalues)

    return conditionals, values


@dataclasses.dataclass(kw_only=True)
class GenericBQFilterModel(GenericFilterModel):
    """
    Class that contains fields of GenericBQFilters that can be used to filter.

    __post_init__ normalises lazily-provided values (plain values, lists,
    dicts) into GenericBQFilter instances so to_sql() only has one shape to
    deal with.
    """

    def __post_init__(self):
        for field in dataclasses.fields(self):
            value = getattr(self, field.name)
            if value is None:
                continue

            if isinstance(value, tuple) and len(value) == 1 and value[0] is None:
                raise ValueError(
                    'There is very likely a trailing comma on the end of '
                    f'{self.__class__.__name__}.{field.name}. If you actually want a '
                    'tuple of length one with the value = (None,), then use '
                    'dataclasses.field(default_factory=lambda: (None,))'
                )
            if isinstance(value, GenericBQFilter):
                continue

            if isinstance(value, dict):
                # make sure each field is a GenericFilter, or set it to be one,
                # in this case it's always 'eq', never automatically in_
                new_value = {
                    k: v if isinstance(v, GenericBQFilter) else GenericBQFilter(eq=v)
                    for k, v in value.items()
                }
                setattr(self, field.name, new_value)
                continue

            # lazily provided a value, which we'll correct
            if isinstance(value, list):
                setattr(self, field.name, GenericBQFilter(in_=value))
            else:
                setattr(self, field.name, GenericBQFilter(eq=value))

    def to_sql(
        self, field_overrides: dict[str, Any] | None = None
    ) -> tuple[str, dict[str, Any]]:
        """
        Convert the model to SQL, and avoid SQL injection.

        field_overrides maps a dataclass field name to the actual column
        expression to use in the SQL. Returns (where-clause, parameter dict).
        Raises ValueError for overrides that don't match any field.
        """
        _foverrides = field_overrides or {}

        # check for bad field_overrides
        bad_field_overrides = set(_foverrides.keys()) - set(
            f.name for f in dataclasses.fields(self)
        )
        if bad_field_overrides:
            raise ValueError(
                f'Specified field overrides that were not used: {bad_field_overrides}'
            )

        fields = dataclasses.fields(self)
        conditionals, values = [], {}
        for field in fields:
            fcolumn = _foverrides.get(field.name, field.name)
            if filter_ := getattr(self, field.name):
                if isinstance(filter_, dict):
                    meta_conditionals, meta_values = prepare_bq_query_from_dict_field(
                        filter_=filter_, field_name=field.name, column_name=fcolumn
                    )
                    conditionals.extend(meta_conditionals)
                    values.update(meta_values)
                elif isinstance(filter_, GenericBQFilter):
                    fconditionals, fvalues = filter_.to_sql(fcolumn)
                    conditionals.append(fconditionals)
                    values.update(fvalues)
                else:
                    raise ValueError(
                        f'Filter {field.name} must be a GenericBQFilter or '
                        'dict[str, GenericBQFilter]'
                    )

        if not conditionals:
            # no filters: a WHERE clause that matches everything
            return 'True', {}

        return ' AND '.join(filter(None, conditionals)), values
class BillingInternal(SMBase):
    """Internal model for a single billing record."""

    id: str | None
    ar_guid: str | None
    gcp_project: str | None
    topic: str | None
    batch_id: str | None
    cost_category: str | None
    cost: float | None
    day: datetime.date | None

    @staticmethod
    def from_db(**kwargs):
        """
        Construct from db/BQ row keys. Stored data may use 'ar-guid'
        (dash) where the model uses 'ar_guid', so both spellings are
        accepted; all other keys map one-to-one.
        """
        return BillingInternal(
            id=kwargs.get('id'),
            ar_guid=kwargs.get('ar_guid', kwargs.get('ar-guid')),
            gcp_project=kwargs.get('gcp_project'),
            topic=kwargs.get('topic'),
            batch_id=kwargs.get('batch_id'),
            cost_category=kwargs.get('cost_category'),
            cost=kwargs.get('cost'),
            day=kwargs.get('day'),
        )
str_to_enum[adjusted_value] @classmethod def raw_cols(cls) -> list[str]: @@ -201,6 +234,19 @@ def __hash__(self): """Create hash for this object to use in caching""" return hash(self.json()) + def to_filter(self) -> BillingFilter: + """ + Convert to internal analysis filter + """ + billing_filter = BillingFilter() + if self.filters: + # add filters as attributes + for fk, fv in self.filters.items(): + # fk is BillColumn, fv is value + setattr(billing_filter, fk.value, GenericBQFilter(eq=fv)) + + return billing_filter + class BillingTotalCostRecord(SMBase): """Return class for the Billing Total Cost record""" diff --git a/web/src/pages/billing/BillingCostByTime.tsx b/web/src/pages/billing/BillingCostByTime.tsx index 3b6d99092..ba191d335 100644 --- a/web/src/pages/billing/BillingCostByTime.tsx +++ b/web/src/pages/billing/BillingCostByTime.tsx @@ -305,9 +305,15 @@ const BillingCostByTime: React.FunctionComponent = () => { setIsLoading(false) setError(undefined) - if (start !== undefined || start !== null || start !== '') { + if (groupBy === undefined || groupBy === null) { + // Group By not selected + setMessage('Please select Group By') + } else if (selectedData === undefined || selectedData === null || selectedData === '') { + // Top Level not selected + setMessage(`Please select ${groupBy}`) + } else if (start === undefined || start === null || start === '') { setMessage('Please select Start date') - } else if (end !== undefined || end !== null || end !== '') { + } else if (end === undefined || end === null || end === '') { setMessage('Please select End date') } else { // generic message diff --git a/web/src/shared/components/Graphs/HorizontalStackedBarChart.tsx b/web/src/shared/components/Graphs/HorizontalStackedBarChart.tsx index e317311d8..204d7cc7a 100644 --- a/web/src/shared/components/Graphs/HorizontalStackedBarChart.tsx +++ b/web/src/shared/components/Graphs/HorizontalStackedBarChart.tsx @@ -31,7 +31,7 @@ const HorizontalStackedBarChart: React.FC = ({ 
isLoading, showLegend, }) => { - if (!data || data.length === 0) { + if (!isLoading && (!data || data.length === 0)) { return
No data available
}