diff --git a/api/routes/billing.py b/api/routes/billing.py index 7d93a599c..3a830197d 100644 --- a/api/routes/billing.py +++ b/api/routes/billing.py @@ -1,17 +1,20 @@ """ Billing routes """ + from async_lru import alru_cache from fastapi import APIRouter +from fastapi.encoders import jsonable_encoder +from fastapi.responses import JSONResponse from api.settings import BILLING_CACHE_RESPONSE_TTL, BQ_AGGREG_VIEW from api.utils.db import BqConnection, get_author from db.python.layers.billing import BillingLayer from models.enums import BillingSource from models.models import ( + BillingBatchCostRecord, BillingColumn, BillingCostBudgetRecord, - BillingHailBatchCostRecord, BillingTotalCostQueryModel, BillingTotalCostRecord, ) @@ -273,34 +276,38 @@ async def get_namespaces( @router.get( '/cost-by-ar-guid/{ar_guid}', - response_model=BillingHailBatchCostRecord, + response_model=list[BillingBatchCostRecord], operation_id='costByArGuid', ) @alru_cache(maxsize=10, ttl=BILLING_CACHE_RESPONSE_TTL) async def get_cost_by_ar_guid( ar_guid: str, author: str = get_author, -) -> BillingHailBatchCostRecord: +) -> JSONResponse: """Get Hail Batch costs by AR GUID""" billing_layer = _get_billing_layer_from(author) records = await billing_layer.get_cost_by_ar_guid(ar_guid) - return records + headers = {'x-bq-cost': str(billing_layer.connection.cost)} + json_compatible_item_data = jsonable_encoder(records) + return JSONResponse(content=json_compatible_item_data, headers=headers) @router.get( '/cost-by-batch-id/{batch_id}', - response_model=BillingHailBatchCostRecord, + response_model=list[BillingBatchCostRecord], operation_id='costByBatchId', ) @alru_cache(maxsize=10, ttl=BILLING_CACHE_RESPONSE_TTL) async def get_cost_by_batch_id( batch_id: str, author: str = get_author, -) -> BillingHailBatchCostRecord: +) -> JSONResponse: """Get Hail Batch costs by Batch ID""" billing_layer = _get_billing_layer_from(author) records = await billing_layer.get_cost_by_batch_id(batch_id) - return records + headers = {'x-bq-cost': str(billing_layer.connection.cost)} + json_compatible_item_data = jsonable_encoder(records) + return JSONResponse(content=json_compatible_item_data, headers=headers) @router.post( diff --git a/api/settings.py b/api/settings.py index 98d11062f..f1d4c1c77 100644 --- a/api/settings.py +++ b/api/settings.py @@ -44,6 +44,9 @@ BQ_GCP_BILLING_VIEW = os.getenv('SM_GCP_BQ_BILLING_VIEW') BQ_BATCHES_VIEW = os.getenv('SM_GCP_BQ_BATCHES_VIEW') +# BQ cost per 1 TB, used to calculate cost of BQ queries +BQ_COST_PER_TB = 6.25 + # This is to optimise BQ queries, DEV table has data only for Mar 2023 BQ_DAYS_BACK_OPTIMAL = 30 # Look back 30 days for optimal query BILLING_CACHE_RESPONSE_TTL = 3600 # 1 Hour diff --git a/db/python/gcp_connect.py b/db/python/gcp_connect.py index 72fbc97df..5d0271e68 100644 --- a/db/python/gcp_connect.py +++ b/db/python/gcp_connect.py @@ -1,6 +1,7 @@ """ Code for connecting to Big Query database """ + import logging import os @@ -23,6 +24,8 @@ def __init__( self.gcp_project = os.getenv('METAMIST_GCP_PROJECT') self.connection: bq.Client = bq.Client(project=self.gcp_project) self.author: str = author + # initialise cost of the query + self._cost: float = 0 @staticmethod async def get_connection_no_project(author: str): @@ -35,6 +38,16 @@ async def get_connection_no_project(author: str): return BqConnection(author=author) + @property + def cost(self) -> float: + """Get the cost of the query""" + return self._cost + + @cost.setter + def cost(self, value: float): + """Set the cost of the query""" + 
self._cost = value + class BqDbBase: """Base class for big query database subclasses""" diff --git a/db/python/layers/billing.py b/db/python/layers/billing.py index 1122259ba..c92692c2c 100644 --- a/db/python/layers/billing.py +++ b/db/python/layers/billing.py @@ -4,11 +4,11 @@ from db.python.tables.bq.billing_daily_extended import BillingDailyExtendedTable from db.python.tables.bq.billing_gcp_daily import BillingGcpDailyTable from db.python.tables.bq.billing_raw import BillingRawTable -from models.enums import BillingSource, BillingTimeColumn, BillingTimePeriods +from models.enums import BillingSource from models.models import ( + BillingBatchCostRecord, BillingColumn, BillingCostBudgetRecord, - BillingHailBatchCostRecord, BillingTotalCostQueryModel, ) @@ -32,6 +32,8 @@ def table_factory( return BillingGcpDailyTable(self.connection) if source == BillingSource.RAW: return BillingRawTable(self.connection) + if source == BillingSource.EXTENDED: + return BillingDailyExtendedTable(self.connection) # check if any of the fields is in the extended columns if fields: @@ -202,7 +204,7 @@ async def get_running_cost( async def get_cost_by_ar_guid( self, ar_guid: str | None = None, - ) -> BillingHailBatchCostRecord: + ) -> list[BillingBatchCostRecord]: """ Get Costs by AR GUID """ @@ -216,44 +218,18 @@ async def get_cost_by_ar_guid( ) = await ar_batch_lookup_table.get_batches_by_ar_guid(ar_guid) if not batches: - return BillingHailBatchCostRecord( - ar_guid=ar_guid, - batch_ids=[], - costs=[], - ) + return [] - # Then get the costs for the given AR GUID/batches from the main table - all_cols = [BillingColumn.str_to_enum(v) for v in BillingColumn.raw_cols()] - - query = BillingTotalCostQueryModel( - fields=all_cols, - source=BillingSource.RAW, - start_date=start_day.strftime('%Y-%m-%d'), - end_date=end_day.strftime('%Y-%m-%d'), - filters={ - BillingColumn.LABELS: { - 'batch_id': batches, - 'ar-guid': ar_guid, - } - }, - filters_op='OR', - group_by=False, - time_column=BillingTimeColumn.USAGE_END_TIME, - time_periods=BillingTimePeriods.DAY, - ) - - billing_table = self.table_factory(query.source, query.fields) - records = await billing_table.get_total_cost(query) - return BillingHailBatchCostRecord( - ar_guid=ar_guid, - batch_ids=batches, - costs=records, + billing_table = BillingDailyExtendedTable(self.connection) + results = await billing_table.get_batch_cost_summary( + start_day, end_day, batches, ar_guid ) + return results async def get_cost_by_batch_id( self, batch_id: str | None = None, - ) -> BillingHailBatchCostRecord: + ) -> list[BillingBatchCostRecord]: """ Get Costs by Batch ID """ @@ -270,31 +246,10 @@ async def get_cost_by_batch_id( ) = await ar_batch_lookup_table.get_batches_by_ar_guid(ar_guid) if not batches: - return BillingHailBatchCostRecord(ar_guid=ar_guid, batch_ids=[], costs=[]) + return [] - # Then get the costs for the given AR GUID/batches from the main table - all_cols = [BillingColumn.str_to_enum(v) for v in BillingColumn.raw_cols()] - - query = BillingTotalCostQueryModel( - fields=all_cols, - source=BillingSource.RAW, - start_date=start_day.strftime('%Y-%m-%d'), - end_date=end_day.strftime('%Y-%m-%d'), - filters={ - BillingColumn.LABELS: { - 'batch_id': batches, - 'ar-guid': ar_guid, - } - }, - filters_op='OR', - group_by=False, - time_column=BillingTimeColumn.USAGE_END_TIME, - time_periods=BillingTimePeriods.DAY, - ) - billing_table = self.table_factory(query.source, query.fields) - records = await billing_table.get_total_cost(query) - return BillingHailBatchCostRecord( - 
ar_guid=ar_guid, - batch_ids=batches, - costs=records, + billing_table = BillingDailyExtendedTable(self.connection) + results = await billing_table.get_batch_cost_summary( + start_day, end_day, batches, ar_guid ) + return results diff --git a/db/python/layers/bq_base.py b/db/python/layers/bq_base.py index f1993060a..8897dacd2 100644 --- a/db/python/layers/bq_base.py +++ b/db/python/layers/bq_base.py @@ -11,3 +11,8 @@ def __init__(self, connection: BqConnection): def author(self): """Get author from connection""" return self.connection.author + + @property + def cost(self): + """Get cost from connection""" + return self.connection.cost diff --git a/db/python/tables/assay.py b/db/python/tables/assay.py index 95837e0f3..1b5645d7a 100644 --- a/db/python/tables/assay.py +++ b/db/python/tables/assay.py @@ -197,7 +197,8 @@ async def get_assay_type_numbers_by_batch_for_project(self, project: ProjectId): """ rows = await self.connection.fetch_all(_query, {'project': project}) batch_result: dict[str, dict[str, str]] = defaultdict(dict) - for batch, seqType, count in rows: + for row in rows: + batch, seqType, count = row['batch'], row['type'], row['n'] batch = str(batch).strip('\"') if batch != 'null' else 'no-batch' batch_result[batch][seqType] = str(count) if len(batch_result) == 1 and 'no-batch' in batch_result: diff --git a/db/python/tables/bq/billing_base.py b/db/python/tables/bq/billing_base.py index 350dfecab..bdbdb14b1 100644 --- a/db/python/tables/bq/billing_base.py +++ b/db/python/tables/bq/billing_base.py @@ -6,7 +6,7 @@ from google.cloud import bigquery -from api.settings import BQ_BUDGET_VIEW, BQ_DAYS_BACK_OPTIMAL +from api.settings import BQ_BUDGET_VIEW, BQ_COST_PER_TB, BQ_DAYS_BACK_OPTIMAL from api.utils.dates import get_invoice_month_range, reformat_datetime from db.python.gcp_connect import BqDbBase from db.python.tables.bq.billing_filter import BillingFilter @@ -99,8 +99,10 @@ def get_table_name(self): raise NotImplementedError('Calling Abstract method directly') def _execute_query( - self, query: str, params: list[Any] = None, results_as_list: bool = True - ) -> list[Any]: + self, query: str, params: list[Any] | None = None, results_as_list: bool = True + ) -> ( + list[Any] | bigquery.table.RowIterator | bigquery.table._EmptyRowIterator | None + ): """Execute query, add BQ labels""" if params: job_config = bigquery.QueryJobConfig( @@ -109,13 +111,31 @@ else: job_config = bigquery.QueryJobConfig(labels=BQ_LABELS) + # We need to dry run to calculate the costs + # executing the query does not provide the cost + # more info here: + # https://stackoverflow.com/questions/58561153/what-is-the-python-api-i-can-use-to-calculate-the-cost-of-a-bigquery-query/58561358#58561358 + job_config.dry_run = True + job_config.use_query_cache = False + query_job = self._connection.connection.query(query, job_config=job_config) + + # This should be thread/async safe as each request + # creates a new connection instance + # and queries per request are run in sequential order, + # waiting for the previous one to finish + self._connection.cost += ( + query_job.total_bytes_processed / 1024**4 + ) * BQ_COST_PER_TB + + # now execute the query + job_config.dry_run = False + job_config.use_query_cache = True + query_job = self._connection.connection.query(query, job_config=job_config) if results_as_list: - return list( - self._connection.connection.query(query, job_config=job_config).result() - ) + return list(query_job.result()) # otherwise return as BQ iterator - return
self._connection.connection.query(query, job_config=job_config) + return query_job @staticmethod def _query_to_partitioned_filter( @@ -129,12 +149,16 @@ def _query_to_partitioned_filter( # initial partition filter billing_filter.day = GenericBQFilter[datetime]( - gte=datetime.strptime(query.start_date, '%Y-%m-%d') - if query.start_date - else None, - lte=datetime.strptime(query.end_date, '%Y-%m-%d') - if query.end_date - else None, + gte=( + datetime.strptime(query.start_date, '%Y-%m-%d') + if query.start_date + else None + ), + lte=( + datetime.strptime(query.end_date, '%Y-%m-%d') + if query.end_date + else None + ), ) return billing_filter @@ -362,17 +386,19 @@ async def _append_total_running_cost( total_monthly=( total_monthly[COMPUTE]['ALL'] + total_monthly[STORAGE]['ALL'] ), - total_daily=(total_daily[COMPUTE]['ALL'] + total_daily[STORAGE]['ALL']) - if is_current_month - else None, + total_daily=( + (total_daily[COMPUTE]['ALL'] + total_daily[STORAGE]['ALL']) + if is_current_month + else None + ), compute_monthly=total_monthly[COMPUTE]['ALL'], - compute_daily=(total_daily[COMPUTE]['ALL']) - if is_current_month - else None, + compute_daily=( + (total_daily[COMPUTE]['ALL']) if is_current_month else None + ), storage_monthly=total_monthly[STORAGE]['ALL'], - storage_daily=(total_daily[STORAGE]['ALL']) - if is_current_month - else None, + storage_daily=( + (total_daily[STORAGE]['ALL']) if is_current_month else None + ), details=all_details, budget_spent=None, budget=None, @@ -422,17 +448,19 @@ async def _append_running_cost_records( { 'field': key, 'total_monthly': monthly, - 'total_daily': (compute_daily + storage_daily) - if is_current_month - else None, + 'total_daily': ( + (compute_daily + storage_daily) + if is_current_month + else None + ), 'compute_monthly': compute_monthly, 'compute_daily': compute_daily, 'storage_monthly': storage_monthly, 'storage_daily': storage_daily, 'details': details, - 'budget_spent': 100 * monthly / budget_monthly - if budget_monthly - else None, + 'budget_spent': ( + 100 * monthly / budget_monthly if budget_monthly else None + ), 'budget': budget_monthly, 'last_loaded_day': last_loaded_day, } diff --git a/db/python/tables/bq/billing_daily_extended.py b/db/python/tables/bq/billing_daily_extended.py index 009144911..5e2d2bf3c 100644 --- a/db/python/tables/bq/billing_daily_extended.py +++ b/db/python/tables/bq/billing_daily_extended.py @@ -1,9 +1,13 @@ +from datetime import datetime + +from google.cloud import bigquery + from api.settings import BQ_AGGREG_EXT_VIEW from db.python.tables.bq.billing_base import ( BillingBaseTable, time_optimisation_parameter, ) -from models.models import BillingColumn +from models.models import BillingBatchCostRecord, BillingColumn class BillingDailyExtendedTable(BillingBaseTable): @@ -49,3 +53,317 @@ async def get_extended_values(self, field: str): # return empty list if no record found return [] + + async def get_batch_cost_summary( + self, + start_time: datetime, + end_time: datetime, + batch_ids: list[str] | None, + ar_guid: str | None, + ) -> list[BillingBatchCostRecord]: + """ + Get summary of AR run + """ + query_parameters = [ + bigquery.ScalarQueryParameter('start_time', 'TIMESTAMP', start_time), + bigquery.ScalarQueryParameter('end_time', 'TIMESTAMP', end_time), + ] + + if ar_guid: + condition = 'ar_guid = @ar_guid' + query_parameters.append( + bigquery.ScalarQueryParameter('ar_guid', 'STRING', ar_guid) + ) + elif batch_ids: + condition = 'batch_id in UNNEST(@batch_ids)' + query_parameters.append( + 
bigquery.ArrayQueryParameter('batch_ids', 'STRING', batch_ids) + ) + else: + raise ValueError('Either ar_guid or batch_ids must be provided') + + # The query will locate all the records for the given time range + # and batch_ids or ar_guid + # getting all records in one go to limit number of queries, + # cutting the total cost of BQ + # After getting all the needed data, + # aggregate: + # the cost per topic, + # get counts of batches, + # cost per hail, cromwell and dataproc + # start and end date of the usage + # total cost, + # including cost by sku + + _query = f""" + -- get all data we need for aggregations in one go + WITH d AS ( + SELECT topic, ar_guid, + CASE WHEN compute_category IS NULL AND batch_id IS NOT NULL THEN + 'hail batch' ELSE compute_category END AS category, + batch_id, CAST(job_id AS INT) job_id, + sku, cost, + usage_start_time, usage_end_time, sequencing_group, stage, wdl_task_name, cromwell_sub_workflow_name, cromwell_workflow_id, + JSON_VALUE(PARSE_JSON(labels), '$.batch_name') as batch_name, + JSON_VALUE(PARSE_JSON(labels), '$.job_name') as job_name + FROM `{self.table_name}` + WHERE day >= @start_time + AND day <= @end_time + AND {condition} + ) + -- sku per ar-guid record + , arsku AS ( + SELECT sku, SUM(cost) as cost + FROM d + WHERE cost != 0 + GROUP BY sku + ORDER BY cost DESC + ) + , arskuacc AS ( + SELECT ARRAY_AGG(STRUCT(sku, cost)) AS skus, + FROM arsku + ) + -- sequencing group record + ,seqgrp AS ( + select sequencing_group, stage, sum(cost) as cost from d group by stage, sequencing_group + ) + ,seqgrpacc AS ( + SELECT ARRAY_AGG(STRUCT(sequencing_group, stage, cost)) as seq_groups + FROM seqgrp + ) + -- topics record + , t3 AS ( + SELECT topic, SUM(cost) AS cost + FROM d + WHERE topic IS NOT NULL + GROUP BY topic + ORDER BY cost DESC + ) + , t3acc AS ( + SELECT ARRAY_AGG(STRUCT(topic, cost)) AS topics, + FROM t3 + ) + -- category specific record + , cc AS ( + SELECT category, sum(cost) AS cost, + CASE WHEN category = 'hail batch' THEN COUNT(DISTINCT batch_id) ELSE NULL END as workflows + FROM d + GROUP BY category + ) + , ccacc AS ( + SELECT ARRAY_AGG(STRUCT(cc.category, cc.cost, cc.workflows)) AS categories + FROM cc + ) + -- total record per ar-guid + ,t AS ( + SELECT STRUCT( + ar_guid, sum(d.cost) AS cost, + MIN(usage_start_time) AS usage_start_time, + max(usage_end_time) AS usage_end_time + ) AS total, + FROM d + GROUP BY ar_guid + ) + -- hail batch job record + , jsku AS ( + SELECT batch_id, job_id, sku, SUM(cost) as cost + FROM d + WHERE cost != 0 + AND batch_id IS NOT NULL + AND job_id IS NOT NULL + GROUP BY batch_id, job_id, sku + ORDER BY batch_id, job_id, cost desc + ) + , jskuacc AS ( + SELECT batch_id, job_id, ARRAY_AGG(STRUCT(sku, cost)) as skus + FROM jsku + GROUP BY batch_id, job_id + ) + ,j AS ( + SELECT d.batch_id, d.job_id, d.job_name, sum(d.cost) AS cost, + MIN(usage_start_time) AS usage_start_time, + max(usage_end_time) AS usage_end_time + FROM d + WHERE d.batch_id IS NOT NULL AND d.job_id IS NOT NULL + GROUP BY d.batch_id, d.job_id, d.job_name + ORDER BY job_id + ) + , jacc AS ( + SELECT j.batch_id, ARRAY_AGG(STRUCT(j.job_id, j.batch_id, job_name, cost, usage_start_time, usage_end_time, jskuacc.skus)) AS jobs + from j INNER JOIN jskuacc on jskuacc.batch_id = j.batch_id AND jskuacc.job_id = j.job_id + GROUP BY j.batch_id + ) + -- hail batch record + , bsku AS ( + SELECT batch_id, sku, SUM(cost) as cost + FROM d + WHERE cost != 0 + AND batch_id IS NOT NULL + GROUP BY batch_id, sku + ORDER BY batch_id, cost desc + ) + , bskuacc AS ( + 
SELECT batch_id, ARRAY_AGG(STRUCT(sku, cost)) as skus + FROM bsku + GROUP BY batch_id + ) + ,b AS ( + SELECT d.batch_id, d.batch_name, + sum(d.cost) AS cost, + MIN(d.usage_start_time) AS usage_start_time, + max(d.usage_end_time) AS usage_end_time, + COUNT(DISTINCT d.job_id) as jobs_cnt + FROM d + WHERE d.batch_id IS NOT NULL + GROUP BY batch_id, batch_name + ORDER BY batch_id + ) + ,bseqgrp AS ( + select batch_id, sequencing_group, stage, sum(cost) as cost from d group by batch_id, sequencing_group, stage + ) + ,bseqgrpacc AS ( + SELECT batch_id, ARRAY_AGG(STRUCT(sequencing_group, stage, cost)) as seq_groups + FROM bseqgrp + GROUP BY batch_id + ) + ,bacc as ( + SELECT ARRAY_AGG(STRUCT(b.batch_id, batch_name, cost, usage_start_time, usage_end_time, jobs_cnt, bskuacc.skus, jacc.jobs, bseqgrpacc.seq_groups)) AS batches + from b + INNER JOIN bskuacc on bskuacc.batch_id = b.batch_id + INNER JOIN jacc ON jacc.batch_id = b.batch_id + INNER JOIN bseqgrpacc ON bseqgrpacc.batch_id = b.batch_id + ) + -- cromwell specific records: + -- wdl_task record + , wdlsku AS ( + SELECT wdl_task_name, sku, SUM(cost) as cost + FROM d + WHERE cost != 0 + AND wdl_task_name IS NOT NULL + GROUP BY wdl_task_name, sku + ORDER BY wdl_task_name, cost desc + ) + , wdlskuacc AS ( + SELECT wdl_task_name, ARRAY_AGG(STRUCT(sku, cost)) as skus + FROM wdlsku + GROUP BY wdl_task_name + ) + , wdl AS ( + SELECT d.wdl_task_name, + sum(d.cost) AS cost, + MIN(d.usage_start_time) AS usage_start_time, + max(d.usage_end_time) AS usage_end_time, + COUNT(DISTINCT d.job_id) as jobs_cnt + FROM d + WHERE d.wdl_task_name IS NOT NULL + GROUP BY wdl_task_name + ORDER BY wdl_task_name + ) + ,wdlacc as ( + SELECT ARRAY_AGG(STRUCT(wdl.wdl_task_name, cost, usage_start_time, usage_end_time, wdlskuacc.skus)) AS wdl_tasks + from wdl + INNER JOIN wdlskuacc on wdlskuacc.wdl_task_name = wdl.wdl_task_name + ) + -- cromwell_workflow record + , cwisku AS ( + SELECT cromwell_workflow_id, sku, SUM(cost) as cost + FROM d + WHERE cost != 0 + AND cromwell_workflow_id IS NOT NULL + GROUP BY cromwell_workflow_id, sku + ORDER BY cromwell_workflow_id, cost desc + ) + , cwiskuacc AS ( + SELECT cromwell_workflow_id, ARRAY_AGG(STRUCT(sku, cost)) as skus + FROM cwisku + GROUP BY cromwell_workflow_id + ) + , cwi AS ( + SELECT d.cromwell_workflow_id, + sum(d.cost) AS cost, + MIN(d.usage_start_time) AS usage_start_time, + max(d.usage_end_time) AS usage_end_time, + COUNT(DISTINCT d.job_id) as jobs_cnt + FROM d + WHERE d.cromwell_workflow_id IS NOT NULL + GROUP BY cromwell_workflow_id + ORDER BY cromwell_workflow_id + ) + ,cwiacc as ( + SELECT ARRAY_AGG(STRUCT(cwi.cromwell_workflow_id, cost, usage_start_time, usage_end_time, cwiskuacc.skus)) AS cromwell_workflows + from cwi + INNER JOIN cwiskuacc on cwiskuacc.cromwell_workflow_id = cwi.cromwell_workflow_id + ) + -- cromwell_sub_workflow record + , cswsku AS ( + SELECT cromwell_sub_workflow_name, sku, SUM(cost) as cost + FROM d + WHERE cost != 0 + AND cromwell_sub_workflow_name IS NOT NULL + GROUP BY cromwell_sub_workflow_name, sku + ORDER BY cromwell_sub_workflow_name, cost desc + ) + , cswskuacc AS ( + SELECT cromwell_sub_workflow_name, ARRAY_AGG(STRUCT(sku, cost)) as skus + FROM cswsku + GROUP BY cromwell_sub_workflow_name + ) + , csw AS ( + SELECT d.cromwell_sub_workflow_name, + sum(d.cost) AS cost, + MIN(d.usage_start_time) AS usage_start_time, + max(d.usage_end_time) AS usage_end_time, + COUNT(DISTINCT d.job_id) as jobs_cnt + FROM d + WHERE d.cromwell_sub_workflow_name IS NOT NULL + GROUP BY cromwell_sub_workflow_name + 
ORDER BY cromwell_sub_workflow_name + ) + ,cswacc as ( + SELECT ARRAY_AGG(STRUCT(csw.cromwell_sub_workflow_name, cost, usage_start_time, usage_end_time, cswskuacc.skus)) AS cromwell_sub_workflows + from csw + INNER JOIN cswskuacc on cswskuacc.cromwell_sub_workflow_name = csw.cromwell_sub_workflow_name + ) + -- dataproc specific record + , dprocsku AS ( + SELECT category as dataproc, sku, SUM(cost) as cost + FROM d + WHERE category = 'dataproc' + AND cost != 0 + GROUP BY dataproc, sku + ORDER BY cost DESC + ) + , dprocskuacc AS ( + SELECT dataproc, ARRAY_AGG(STRUCT(sku, cost)) AS skus, + FROM dprocsku + GROUP BY dataproc + ) + ,dproc AS ( + SELECT category as dataproc, sum(cost) AS cost, + MIN(usage_start_time) AS usage_start_time, + max(usage_end_time) AS usage_end_time + FROM d + WHERE category = 'dataproc' + GROUP BY dataproc + ) + , dprocacc as ( + SELECT ARRAY_AGG(STRUCT(dproc.dataproc, cost, usage_start_time, usage_end_time, dprocskuacc.skus)) AS dataproc + from dproc + INNER JOIN dprocskuacc on dprocskuacc.dataproc = dproc.dataproc + ) + -- merge all in one record + SELECT t.total, t3acc.topics, ccacc.categories, bacc.batches, arskuacc.skus, + wdlacc.wdl_tasks, cswacc.cromwell_sub_workflows, cwiacc.cromwell_workflows, seqgrpacc.seq_groups, dprocacc.dataproc + FROM t, t3acc, ccacc, bacc, arskuacc, wdlacc, cswacc, cwiacc, seqgrpacc, dprocacc + + """ + + query_job_result = self._execute_query(_query, query_parameters, False) + + if query_job_result: + return [ + BillingBatchCostRecord.from_json(dict(row)) for row in query_job_result + ] + + # return empty list if no record found + return [] diff --git a/db/python/tables/family_participant.py b/db/python/tables/family_participant.py index 8aab115a8..e782038ab 100644 --- a/db/python/tables/family_participant.py +++ b/db/python/tables/family_participant.py @@ -20,7 +20,7 @@ async def create_row( paternal_id: int, maternal_id: int, affected: int, - notes: str = None, + notes: str | None = None, ) -> Tuple[int, int]: """ Create a new sample, and add it to database @@ -111,8 +111,8 @@ async def get_rows( keys = [ 'fp.family_id', 'p.id as individual_id', - 'fp.paternal_participant_id', - 'fp.maternal_participant_id', + 'fp.paternal_participant_id as paternal_id', + 'fp.maternal_participant_id as maternal_id', 'p.reported_sex as sex', 'fp.affected', ] @@ -153,7 +153,7 @@ async def get_rows( 'sex', 'affected', ] - ds = [dict(zip(ordered_keys, row)) for row in rows] + ds = [{k: row[k] for k in ordered_keys} for row in rows] return ds @@ -161,7 +161,7 @@ async def get_row( self, family_id: int, participant_id: int, - ): + ) -> dict | None: """Get a single row from the family_participant table""" values: Dict[str, Any] = { 'family_id': family_id, @@ -169,7 +169,13 @@ async def get_row( } _query = """ -SELECT fp.family_id, p.id as individual_id, fp.paternal_participant_id, fp.maternal_participant_id, p.reported_sex as sex, fp.affected +SELECT + fp.family_id as family_id, + p.id as individual_id, + fp.paternal_participant_id as paternal_id, + fp.maternal_participant_id as maternal_id, + p.reported_sex as sex, + fp.affected FROM family_participant fp INNER JOIN family f ON f.id = fp.family_id INNER JOIN participant p on fp.participant_id = p.id @@ -177,6 +183,8 @@ async def get_row( """ row = await self.connection.fetch_one(_query, values) + if not row: + return None ordered_keys = [ 'family_id', @@ -186,7 +194,7 @@ async def get_row( 'sex', 'affected', ] - ds = dict(zip(ordered_keys, row)) + ds = {k: row[k] for k in ordered_keys} return ds diff --git 
a/db/python/tables/participant_phenotype.py b/db/python/tables/participant_phenotype.py index ea010588d..d50b5bbb7 100644 --- a/db/python/tables/participant_phenotype.py +++ b/db/python/tables/participant_phenotype.py @@ -44,7 +44,7 @@ async def add_key_value_rows(self, rows: List[Tuple[int, str, Any]]) -> None: ) async def get_key_value_rows_for_participant_ids( - self, participant_ids=List[int] + self, participant_ids: List[int] ) -> Dict[int, Dict[str, Any]]: """ Get (participant_id, description, value), @@ -64,7 +64,9 @@ async def get_key_value_rows_for_participant_ids( ) formed_key_value_pairs: Dict[int, Dict[str, Any]] = defaultdict(dict) for row in rows: - pid, key, value = row + pid = row['participant_id'] + key = row['description'] + value = row['value'] formed_key_value_pairs[pid][key] = json.loads(value) return formed_key_value_pairs @@ -86,7 +88,9 @@ async def get_key_value_rows_for_all_participants( rows = await self.connection.fetch_all(_query, {'project': project}) formed_key_value_pairs: Dict[int, Dict[str, Any]] = defaultdict(dict) for row in rows: - pid, key, value = row + pid = row['participant_id'] + key = row['description'] + value = row['value'] formed_key_value_pairs[pid][key] = json.loads(value) return formed_key_value_pairs diff --git a/etl/load/main.py b/etl/load/main.py index 65e7a2dc6..28f3b291d 100644 --- a/etl/load/main.py +++ b/etl/load/main.py @@ -1,6 +1,7 @@ import asyncio import base64 import datetime +import importlib.metadata import json import logging import os @@ -10,7 +11,6 @@ import flask import functions_framework import google.cloud.bigquery as bq -import pkg_resources from google.cloud import pubsub_v1, secretmanager from metamist.parser.generic_parser import GenericParser # type: ignore @@ -25,14 +25,23 @@ @lru_cache def _get_bq_client(): + assert BIGQUERY_TABLE, 'BIGQUERY_TABLE is not set' + assert BIGQUERY_LOG_TABLE, 'BIGQUERY_LOG_TABLE is not set' return bq.Client() @lru_cache def _get_secret_manager(): + assert ETL_ACCESSOR_CONFIG_SECRET, 'CONFIGURATION_SECRET is not set' return secretmanager.SecretManagerServiceClient() +@lru_cache +def _get_pubsub_client(): + assert NOTIFICATION_PUBSUB_TOPIC, 'NOTIFICATION_PUBSUB_TOPIC is not set' + return pubsub_v1.PublisherClient() + + class ParsingStatus: """ Enum type to distinguish between sucess and failure of parsing @@ -65,7 +74,7 @@ def call_parser(parser_obj, row_json) -> tuple[str, str]: async def run_parser_capture_result(parser_obj, row_data, res, status): try: # TODO better error handling - r = await parser_obj.from_json([row_data], confirm=False, dry_run=True) + r = await parser_obj.from_json(row_data, confirm=False) res.append(r) status.append(ParsingStatus.SUCCESS) except Exception as e: # pylint: disable=broad-exception-caught @@ -148,7 +157,7 @@ def process_rows( # publish to notification pubsub msg_title = 'Metamist ETL Load Failed' try: - pubsub_client = pubsub_v1.PublisherClient() + pubsub_client = _get_pubsub_client() pubsub_client.publish( NOTIFICATION_PUBSUB_TOPIC, json.dumps({'title': msg_title} | log_record).encode(), @@ -212,6 +221,15 @@ def etl_load(request: flask.Request): 'message': f'Missing or empty request_id: {jbody_str}', }, 400 + return process_request(request_id, delivery_attempt) + + +def process_request( + request_id: str, delivery_attempt: int | None = None +) -> tuple[dict, int]: + """ + Process request_id, delivery_attempt and return result + """ # locate the request_id in bq query = f""" SELECT * FROM `{BIGQUERY_TABLE}` WHERE request_id = @request_id @@ -322,13 
+340,16 @@ def get_parser_instance( accessor_config: dict[ str, - list[ - dict[ - Literal['name'] - | Literal['parser_name'] - | Literal['default_parameters'], - Any, - ] + dict[ + Literal['parsers'], + list[ + dict[ + Literal['name'] + | Literal['parser_name'] + | Literal['default_parameters'], + Any, + ] + ], ], ] = get_accessor_config() @@ -341,7 +362,7 @@ def get_parser_instance( etl_accessor_config = next( ( accessor_config - for accessor_config in accessor_config[submitting_user] + for accessor_config in accessor_config[submitting_user].get('parsers', []) if accessor_config['name'].strip(STRIP_CHARS) == request_type.strip(STRIP_CHARS) ), @@ -387,7 +408,8 @@ def prepare_parser_map() -> dict[str, type[GenericParser]]: loop through metamist_parser entry points and create map of parsers """ parser_map = {} - for entry_point in pkg_resources.iter_entry_points('metamist_parser'): + + for entry_point in importlib.metadata.entry_points().get('metamist_parser'): parser_cls = entry_point.load() parser_short_name, parser_version = parser_cls.get_info() parser_map[f'{parser_short_name}/{parser_version}'] = parser_cls diff --git a/etl/test/test_etl_load.py b/etl/test/test_etl_load.py index 203bc9f33..a1ccbb4c1 100644 --- a/etl/test/test_etl_load.py +++ b/etl/test/test_etl_load.py @@ -200,13 +200,15 @@ def test_get_parser_instance_success( """Test get_parser_instance success""" mock_get_accessor_config.return_value = { - 'user@test.com': [ - { - 'name': 'test/v1', - 'default_parameters': {}, - # 'parser_name': 'test', - } - ] + 'user@test.com': { + 'parsers': [ + { + 'name': 'test/v1', + 'default_parameters': {}, + # 'parser_name': 'test', + } + ] + } } mock_prepare_parser_map.return_value = { @@ -234,13 +236,15 @@ def test_get_parser_instance_different_parser_name( """Test get_parser_instance success""" mock_get_accessor_config.return_value = { - 'user@test.com': [ - { - 'name': 'test/v1', - 'default_parameters': {'project': 'test'}, - 'parser_name': 'different_parser/name', - } - ] + 'user@test.com': { + 'parsers': [ + { + 'name': 'test/v1', + 'default_parameters': {'project': 'test'}, + 'parser_name': 'different_parser/name', + } + ] + } } mock_prepare_parser_map.return_value = { @@ -284,12 +288,14 @@ def test_get_parser_no_matching_config( """Test get_parser_instance success""" mock_get_accessor_config.return_value = { - 'user@test.com': [ - { - 'name': 'test/v1', - 'default_parameters': {'project': 'test'}, - } - ] + 'user@test.com': { + 'parsers': [ + { + 'name': 'test/v1', + 'default_parameters': {'project': 'test'}, + } + ] + } } # this doesn't need to be mocked as it fails before here @@ -316,12 +322,14 @@ def test_get_parser_no_matching_parser( """Test get_parser_instance success""" mock_get_accessor_config.return_value = { - 'user@test.com': [ - { - 'name': 'a/b', - 'default_parameters': {'project': 'test'}, - } - ] + 'user@test.com': { + 'parsers': [ + { + 'name': 'a/b', + 'default_parameters': {'project': 'test'}, + } + ] + } } mock_prepare_parser_map.return_value = { diff --git a/metamist/parser/generic_parser.py b/metamist/parser/generic_parser.py index 8dc2de5d7..fbe95136b 100644 --- a/metamist/parser/generic_parser.py +++ b/metamist/parser/generic_parser.py @@ -534,6 +534,9 @@ async def from_json(self, rows, confirm=False, dry_run=False): If no participants are present, groups samples by their IDs. For each sample, gets its sequencing groups by their keys. For each sequencing group, groups assays and analyses. 
""" + if not isinstance(rows, list): + rows = [rows] + await self.validate_rows(rows) # one participant with no value diff --git a/models/models/__init__.py b/models/models/__init__.py index ce5868cb1..8b100662f 100644 --- a/models/models/__init__.py +++ b/models/models/__init__.py @@ -12,10 +12,10 @@ from models.models.assay import Assay, AssayInternal, AssayUpsert, AssayUpsertInternal from models.models.audit_log import AuditLogId, AuditLogInternal from models.models.billing import ( + BillingBatchCostRecord, BillingColumn, BillingCostBudgetRecord, BillingCostDetailsRecord, - BillingHailBatchCostRecord, BillingInternal, BillingTotalCostQueryModel, BillingTotalCostRecord, diff --git a/models/models/billing.py b/models/models/billing.py index 6b5a98fbf..83cb7586f 100644 --- a/models/models/billing.py +++ b/models/models/billing.py @@ -135,25 +135,13 @@ def str_to_enum(cls, value: str) -> 'BillingColumn': def raw_cols(cls) -> list[str]: """Return list of raw column names""" return [ - BillingColumn.ID.value, BillingColumn.TOPIC.value, - BillingColumn.SERVICE.value, BillingColumn.SKU.value, BillingColumn.USAGE_START_TIME.value, BillingColumn.USAGE_END_TIME.value, - BillingColumn.PROJECT.value, BillingColumn.LABELS.value, - BillingColumn.SYSTEM_LABELS.value, - BillingColumn.LOCATION.value, - BillingColumn.EXPORT_TIME.value, BillingColumn.COST.value, BillingColumn.CURRENCY.value, - BillingColumn.CURRENCY_CONVERSION_RATE.value, - BillingColumn.USAGE.value, - BillingColumn.CREDITS.value, - BillingColumn.INVOICE.value, - BillingColumn.COST_TYPE.value, - BillingColumn.ADJUSTMENT_INFO.value, ] @classmethod @@ -362,9 +350,33 @@ def from_json(record): ) -class BillingHailBatchCostRecord(SMBase): +class BillingBatchCostRecord(SMBase): """Return class for the Billing Cost by batch_id/ar_guid""" - ar_guid: str | None - batch_ids: list[str] | None - costs: list[dict] | None + total: dict | None + topics: list[dict] | None + categories: list[dict] | None + batches: list[dict] | None + skus: list[dict] | None + seq_groups: list[dict] | None + + wdl_tasks: list[dict] | None + cromwell_sub_workflows: list[dict] | None + cromwell_workflows: list[dict] | None + dataproc: list[dict] | None + + @staticmethod + def from_json(record): + """Create BillingBatchCostRecord from json""" + return BillingBatchCostRecord( + total=record.get('total'), + topics=record.get('topics'), + categories=record.get('categories'), + batches=record.get('batches'), + skus=record.get('skus'), + seq_groups=record.get('seq_groups'), + wdl_tasks=record.get('wdl_tasks'), + cromwell_sub_workflows=record.get('cromwell_sub_workflows'), + cromwell_workflows=record.get('cromwell_workflows'), + dataproc=record.get('dataproc'), + ) diff --git a/requirements-dev.txt b/requirements-dev.txt index 3bb86aa87..c73f19f70 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -7,7 +7,7 @@ flake8-bugbear nest-asyncio pre-commit pylint -testcontainers[mariadb]==3.7.1 +testcontainers[mariadb]>=4.0.0 types-PyMySQL # some strawberry dependency strawberry-graphql[debug-server]==0.206.0 diff --git a/requirements.txt b/requirements.txt index 1f9640ea2..172fa8fe0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,8 +17,8 @@ uvicorn==0.18.3 fastapi[all]==0.85.1 strawberry-graphql[fastapi]==0.206.0 python-multipart==0.0.5 -databases[mysql]==0.6.1 -SQLAlchemy==1.4.41 +databases[mysql]==0.9.0 +SQLAlchemy==2.0.28 cryptography>=41.0.0 python-dateutil==2.8.2 slack-sdk==3.20.2 diff --git a/test/test_api_billing.py b/test/test_api_billing.py index 
4811187a7..2f5c00665 100644 --- a/test/test_api_billing.py +++ b/test/test_api_billing.py @@ -1,13 +1,14 @@ # pylint: disable=protected-access too-many-public-methods +import json from test.testbase import run_as_sync from test.testbqbase import BqTest from unittest.mock import patch from api.routes import billing from models.models import ( + BillingBatchCostRecord, BillingColumn, BillingCostBudgetRecord, - BillingHailBatchCostRecord, BillingTotalCostQueryModel, BillingTotalCostRecord, ) @@ -54,15 +55,28 @@ async def test_get_cost_by_ar_guid( Test get_cost_by_ar_guid function """ ar_guid = 'test_ar_guid' - mockup_record = BillingHailBatchCostRecord( - ar_guid=ar_guid, batch_ids=None, costs=None - ) + mockup_record_json = { + 'total': {'ar_guid': ar_guid}, + 'topics': None, + 'categories': None, + 'batches': [], + 'skus': None, + 'seq_groups': None, + 'wdl_tasks': None, + 'cromwell_sub_workflows': None, + 'cromwell_workflows': None, + 'dataproc': None, + } + + mockup_record = [BillingBatchCostRecord.from_json(mockup_record_json)] mock_get_billing_layer.return_value = self.layer mock_get_cost_by_ar_guid.return_value = mockup_record - records = await billing.get_cost_by_ar_guid( + response = await billing.get_cost_by_ar_guid( ar_guid, author=TEST_API_BILLING_USER ) - self.assertEqual(mockup_record, records) + self.assertEqual( + [mockup_record_json], json.loads(response.body.decode('utf-8')) + ) @run_as_sync @patch('api.routes.billing._get_billing_layer_from') @@ -75,15 +89,28 @@ async def test_get_cost_by_batch_id( """ ar_guid = 'test_ar_guid' batch_id = 'test_batch_id' - mockup_record = BillingHailBatchCostRecord( - ar_guid=ar_guid, batch_ids=[batch_id], costs=None - ) + mockup_record_json = { + 'total': {'ar_guid': ar_guid}, + 'topics': None, + 'categories': None, + 'batches': [{'batch_id': batch_id}], + 'skus': None, + 'seq_groups': None, + 'wdl_tasks': None, + 'cromwell_sub_workflows': None, + 'cromwell_workflows': None, + 'dataproc': None, + } + + mockup_record = [BillingBatchCostRecord.from_json(mockup_record_json)] mock_get_billing_layer.return_value = self.layer mock_get_cost_by_batch_id.return_value = mockup_record - records = await billing.get_cost_by_batch_id( + response = await billing.get_cost_by_batch_id( batch_id, author=TEST_API_BILLING_USER ) - self.assertEqual(mockup_record, records) + self.assertEqual( + [mockup_record_json], json.loads(response.body.decode('utf-8')) + ) @run_as_sync @patch('api.routes.billing._get_billing_layer_from') diff --git a/test/test_bq_billing_base.py b/test/test_bq_billing_base.py index 655b40486..609983790 100644 --- a/test/test_bq_billing_base.py +++ b/test/test_bq_billing_base.py @@ -362,9 +362,10 @@ def test_execute_query_results_not_as_list(self): given_bq_results = [[], [123], ['a', 'b', 'c']] for bq_result in given_bq_results: # mock BigQuery result - self.bq_client.query.return_value = bq_result + self.bq_result.result.return_value = bq_result + self.bq_result.total_bytes_processed = 0 results = self.table_obj._execute_query( - sql_query, sql_params, results_as_list=False + sql_query, sql_params, results_as_list=True ) self.assertEqual(bq_result, results) @@ -380,9 +381,10 @@ def test_execute_query_with_sql_params(self): given_bq_results = [[], [123], ['a', 'b', 'c']] for bq_result in given_bq_results: # mock BigQuery result - self.bq_client.query.return_value = bq_result + self.bq_result.result.return_value = bq_result + self.bq_result.total_bytes_processed = 0 results = self.table_obj._execute_query( - sql_query, sql_params, 
results_as_list=False + sql_query, sql_params, results_as_list=True ) self.assertEqual(bq_result, results) diff --git a/test/test_layers_billing.py b/test/test_layers_billing.py index be13f3fa3..535d374ad 100644 --- a/test/test_layers_billing.py +++ b/test/test_layers_billing.py @@ -8,11 +8,7 @@ from db.python.layers.billing import BillingLayer from models.enums import BillingSource -from models.models import ( - BillingColumn, - BillingHailBatchCostRecord, - BillingTotalCostQueryModel, -) +from models.models import BillingColumn, BillingTotalCostQueryModel class TestBillingLayer(BqTest): @@ -359,7 +355,11 @@ async def test_get_running_cost(self): @run_as_sync async def test_get_cost_by_ar_guid(self): - """Test get_cost_by_ar_guid""" + """ + Test get_cost_by_ar_guid + This test only paths in the layer, + the logic and processing is tested in test/test_bq_billing_base.py + """ layer = BillingLayer(self.connection) @@ -367,9 +367,7 @@ async def test_get_cost_by_ar_guid(self): records = await layer.get_cost_by_ar_guid(ar_guid=None) # return empty record - self.assertEqual( - BillingHailBatchCostRecord(ar_guid=None, batch_ids=[], costs=[]), records - ) + self.assertEqual([], records) # dummy ar_guid, no mockup data, return empty results dummy_ar_guid = '12345678' @@ -377,13 +375,11 @@ async def test_get_cost_by_ar_guid(self): # return empty record self.assertEqual( - BillingHailBatchCostRecord(ar_guid=dummy_ar_guid, batch_ids=[], costs=[]), + [], records, ) - # dummy ar_guid, mockup batch_id - - # mock BigQuery result + # mock BigQuery first query result given_start_day = datetime.datetime(2023, 1, 1, 0, 0) given_end_day = datetime.datetime(2023, 1, 1, 2, 3) dummy_batch_id = '12345' @@ -401,19 +397,21 @@ async def test_get_cost_by_ar_guid(self): self.bq_result.result.return_value = mock_rows records = await layer.get_cost_by_ar_guid(ar_guid=dummy_ar_guid) - # returns ar_guid, batch_id and empty cost as those were not mocked up + # returns empty list as those were not mocked up # we do not need to test cost calculation here, # as those are tested in test/test_bq_billing_base.py self.assertEqual( - BillingHailBatchCostRecord( - ar_guid=dummy_ar_guid, batch_ids=[dummy_batch_id], costs=[{}] - ), + [], records, ) @run_as_sync async def test_get_cost_by_batch_id(self): - """Test get_cost_by_batch_id""" + """ + Test get_cost_by_batch_id + This test only paths in the layer, + the logic and processing is tested in test/test_bq_billing_base.py + """ layer = BillingLayer(self.connection) @@ -421,9 +419,7 @@ async def test_get_cost_by_batch_id(self): records = await layer.get_cost_by_batch_id(batch_id=None) # return empty record - self.assertEqual( - BillingHailBatchCostRecord(ar_guid=None, batch_ids=[], costs=[]), records - ) + self.assertEqual([], records) # dummy ar_guid, no mockup data, return empty results dummy_batch_id = '12345' @@ -431,7 +427,7 @@ async def test_get_cost_by_batch_id(self): # return empty record self.assertEqual( - BillingHailBatchCostRecord(ar_guid=None, batch_ids=[], costs=[]), + [], records, ) @@ -459,12 +455,10 @@ async def test_get_cost_by_batch_id(self): self.bq_result.result.return_value = mock_rows records = await layer.get_cost_by_batch_id(batch_id=dummy_batch_id) - # returns ar_guid, batch_id and empty cost as those were not mocked up + # returns elmpty list as those were not mocked up # we do not need to test cost calculation here, # as those are tested in test/test_bq_billing_base.py self.assertEqual( - BillingHailBatchCostRecord( - ar_guid=dummy_ar_guid, 
batch_ids=[dummy_batch_id], costs=[{}] - ), + [], records, ) diff --git a/test/test_web.py b/test/test_web.py index cc5a0bb63..cc7ffad99 100644 --- a/test/test_web.py +++ b/test/test_web.py @@ -221,7 +221,7 @@ async def test_project_summary_empty(self): seqr_sync_types=[], ) - self.assertEqual(expected, result) + self.assertDataclassEqual(expected, result) @run_as_sync async def test_project_summary_single_entry(self): @@ -232,7 +232,7 @@ async def test_project_summary_single_entry(self): result = await self.webl.get_project_summary(token=0, grid_filter=[]) result.participants = [] - self.assertEqual(SINGLE_PARTICIPANT_RESULT, result) + self.assertDataclassEqual(SINGLE_PARTICIPANT_RESULT, result) @run_as_sync async def test_project_summary_to_external(self): @@ -289,7 +289,7 @@ async def project_summary_with_filter_with_results(self): ], ) filtered_result_success.participants = [] - self.assertEqual(SINGLE_PARTICIPANT_RESULT, filtered_result_success) + self.assertDataclassEqual(SINGLE_PARTICIPANT_RESULT, filtered_result_success) @run_as_sync async def project_summary_with_filter_no_results(self): @@ -323,7 +323,7 @@ async def project_summary_with_filter_no_results(self): seqr_sync_types=[], ) - self.assertEqual(empty_result, filtered_result_empty) + self.assertDataclassEqual(empty_result, filtered_result_empty) @run_as_sync async def test_project_summary_multiple_participants(self): @@ -376,7 +376,7 @@ async def test_project_summary_multiple_participants(self): two_samples_result.participants = [] - self.assertEqual(expected_data_two_samples, two_samples_result) + self.assertDataclassEqual(expected_data_two_samples, two_samples_result) @run_as_sync async def test_project_summary_multiple_participants_and_filter(self): @@ -436,7 +436,7 @@ async def test_project_summary_multiple_participants_and_filter(self): ) two_samples_result_filtered.participants = [] - self.assertEqual( + self.assertDataclassEqual( expected_data_two_samples_filtered, two_samples_result_filtered ) @@ -499,7 +499,9 @@ async def test_field_with_space(self): seqr_sync_types=[], ) - self.assertEqual(expected_data_two_samples_filtered, test_field_with_space) + self.assertDataclassEqual( + expected_data_two_samples_filtered, test_field_with_space + ) @run_as_sync async def test_project_summary_inactive_sequencing_group(self): diff --git a/test/testbase.py b/test/testbase.py index 04c8e1d2d..fa60ee66f 100644 --- a/test/testbase.py +++ b/test/testbase.py @@ -1,6 +1,7 @@ # pylint: disable=invalid-overridden-method import asyncio +import dataclasses import logging import os import socket @@ -96,10 +97,10 @@ async def setup(): logger = logging.getLogger() try: set_all_access(True) - db = MySqlContainer('mariadb:10.8.3') + db = MySqlContainer('mariadb:11.2.2') port_to_expose = find_free_port() # override the default port to map the container to - db.with_bind_ports(db.port_to_expose, port_to_expose) + db.with_bind_ports(db.port, port_to_expose) logger.disabled = True db.start() logger.disabled = False @@ -111,7 +112,7 @@ async def setup(): con_string = db.get_connection_url() con_string = 'mysql://' + con_string.split('://', maxsplit=1)[1] - lcon_string = f'jdbc:mariadb://{db.get_container_host_ip()}:{port_to_expose}/{db.MYSQL_DATABASE}' + lcon_string = f'jdbc:mariadb://{db.get_container_host_ip()}:{port_to_expose}/{db.dbname}' # apply the liquibase schema command = [ 'liquibase', @@ -120,8 +121,8 @@ async def setup(): *('--url', lcon_string), *('--driver', 'org.mariadb.jdbc.Driver'), *('--classpath', db_prefix + 
'/mariadb-java-client-3.0.3.jar'), - *('--username', db.MYSQL_USER), - *('--password', db.MYSQL_PASSWORD), + *('--username', db.username), + *('--password', db.password), 'update', ] subprocess.check_output(command, stderr=subprocess.STDOUT) @@ -175,7 +176,7 @@ async def setup(): def tearDownClass(cls) -> None: db = cls.dbs.get(cls.__name__) if db: - db.exec(f'DROP DATABASE {db.MYSQL_DATABASE};') + db.exec(f'DROP DATABASE {db.dbname};') db.stop() def setUp(self) -> None: @@ -224,6 +225,19 @@ async def audit_log_id(self): """Get audit_log_id for the test""" return await self.connection.audit_log_id() + def assertDataclassEqual(self, a, b): + """Assert two dataclasses are equal""" + + def to_dict(obj): + d = dataclasses.asdict(obj) + for k, v in d.items(): + if dataclasses.is_dataclass(v): + d[k] = to_dict(v) + return d + + self.maxDiff = None + self.assertDictEqual(to_dict(a), to_dict(b)) + class DbIsolatedTest(DbTest): """ diff --git a/web/package-lock.json b/web/package-lock.json index 62926084c..f04895d4d 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -1,12 +1,12 @@ { "name": "metamist", - "version": "6.4.0", + "version": "6.6.2", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "metamist", - "version": "6.4.0", + "version": "6.6.2", "dependencies": { "@apollo/client": "^3.7.3", "@artsy/fresnel": "^6.2.1", @@ -35,6 +35,7 @@ "react-responsive": "^9.0.2", "react-router-dom": "^6.0.1", "react-syntax-highlighter": "^15.4.4", + "react-virtuoso": "^4.6.3", "recharts": "^2.8.0", "remark-gfm": "^3.0.1", "remark-toc": "^8.0.1", @@ -10171,6 +10172,18 @@ "react-dom": ">=16.6.0" } }, + "node_modules/react-virtuoso": { + "version": "4.6.3", + "resolved": "https://registry.npmjs.org/react-virtuoso/-/react-virtuoso-4.6.3.tgz", + "integrity": "sha512-NcoSsf4B0OCx7U8i2s+VWe8b9e+FWzcN/5ly4hKjErynBzGONbWORZ1C5amUlWrPi6+HbUQ2PjnT4OpyQIpP9A==", + "engines": { + "node": ">=10" + }, + "peerDependencies": { + "react": ">=16 || >=17 || >= 18", + "react-dom": ">=16 || >=17 || >= 18" + } + }, "node_modules/readable-stream": { "version": "3.6.2", "dev": true, diff --git a/web/package.json b/web/package.json index 852999cee..1c1527f30 100644 --- a/web/package.json +++ b/web/package.json @@ -30,6 +30,7 @@ "react-responsive": "^9.0.2", "react-router-dom": "^6.0.1", "react-syntax-highlighter": "^15.4.4", + "react-virtuoso": "^4.6.3", "recharts": "^2.8.0", "remark-gfm": "^3.0.1", "remark-toc": "^8.0.1", diff --git a/web/src/pages/billing/BillingCostByAnalysis.tsx b/web/src/pages/billing/BillingCostByAnalysis.tsx index 87e1ef153..d18b5d58d 100644 --- a/web/src/pages/billing/BillingCostByAnalysis.tsx +++ b/web/src/pages/billing/BillingCostByAnalysis.tsx @@ -4,10 +4,12 @@ import { Button, Card, Grid, Input, Message, Select, Dropdown } from 'semantic-u import SearchIcon from '@mui/icons-material/Search' import LoadingDucks from '../../shared/components/LoadingDucks/LoadingDucks' -import { BillingApi, BillingTotalCostRecord } from '../../sm-api' -import HailBatchGrid from './components/HailBatchGrid' +import { BillingApi, BillingTotalCostRecord, AnalysisApi } from '../../sm-api' +import BatchGrid from './components/BatchGrid' + import { getMonthStartDate } from '../../shared/utilities/monthStartEndDate' import generateUrl from '../../shared/utilities/generateUrl' +import { List } from 'lodash' enum SearchType { Ar_guid, @@ -27,6 +29,32 @@ const BillingCostByAnalysis: React.FunctionComponent = () => { const [data, setData] = React.useState(undefined) + const setArData = (arData: []) => { 
+ setIsLoading(false) + // arData is an array of objects, we use only the first object + // in the future we may have search by several ar_guids / author etc. + if (arData === undefined || arData.length === 0) { + // nothing found + setIsLoading(false) + return + } + const ar_record = arData[0] + if (!!ar_record?.total?.ar_guid) { + new AnalysisApi() + .getAnalysisRunnerLog(undefined, undefined, ar_record.total.ar_guid, undefined) + .then((response) => { + // combine arData and getAnalysisRunnerLog + if (response.data.length > 0) { + // use only the first record for now + ar_record.analysisRunnerLog = response.data[0] + } + setData(ar_record) + }) + .catch((er) => setError(er.message)) + } + setIsLoading(false) + } + const [searchTxt, setSearchTxt] = React.useState(searchParams.get('searchTxt') ?? '') const searchOptions: string[] = Object.keys(SearchType).filter((item) => isNaN(Number(item))) @@ -51,7 +79,7 @@ const BillingCostByAnalysis: React.FunctionComponent = () => { navigate(url) } - const getData = (sType: SearchType | undefined | string, sTxt: string) => { + const getArData = (sType: SearchType | undefined | string, sTxt: string) => { if ((sType === undefined || sTxt === undefined) && sTxt.length < 6) { // Seaarch text is not large enough setIsLoading(false) return @@ -67,16 +95,14 @@ const BillingCostByAnalysis: React.FunctionComponent = () => { new BillingApi() .costByArGuid(sTxt) .then((response) => { - setIsLoading(false) - setData(response.data) + setArData(response.data) }) .catch((er) => setError(er.message)) } else if (convertedType === SearchType.Batch_id) { new BillingApi() .costByBatchId(sTxt) .then((response) => { - setIsLoading(false) - setData(response.data) + setArData(response.data) }) .catch((er) => setError(er.message)) } else { @@ -86,11 +112,11 @@ const BillingCostByAnalysis: React.FunctionComponent = () => { const handleSearch = () => { if (searchByType === undefined || searchTxt === undefined || searchTxt.length < 6) { - // Seaarch text is not large enough + // Search text is not large enough setIsLoading(false) return } - getData(searchByType, searchTxt) + getArData(searchByType, searchTxt) } const handleSearchChange = (event: any, dt: any) => { @@ -193,21 +219,23 @@ const BillingCostByAnalysis: React.FunctionComponent = () => {
Ar guid: f5a065d2-c51f-46b7-a920-a89b639fc4ba
- Batch id: 430604, 430605 + Batch id: 430604 +
+ Hail Batch + DataProc: 433599 +
+ Cromwell: ec3f961f-7e16-4fb0-a3e3-9fc93006ab42 +
+ Large Hail Batch with 7.5K jobs: a449eea5-7150-441a-9ffe-bd71587c3fe2

) - const gridCard = (gridData: BillingTotalCostRecord[]) => ( - - - - ) + const batchGrid = (gridData: BillingTotalCostRecord) => const dataComponent = () => { - if (data !== undefined && data.costs.length > 0) { + if (data !== undefined) { // only render grid if there are available cost data - return gridCard(data.costs) + return batchGrid(data) } // if valid search text and no data return return No data message diff --git a/web/src/pages/billing/components/BatchGrid.tsx b/web/src/pages/billing/components/BatchGrid.tsx new file mode 100644 index 000000000..def276078 --- /dev/null +++ b/web/src/pages/billing/components/BatchGrid.tsx @@ -0,0 +1,573 @@ +import * as React from 'react' +import { Table as SUITable, Card, Checkbox } from 'semantic-ui-react' +import _ from 'lodash' +import { DonutChart } from '../../../shared/components/Graphs/DonutChart' +import '../../project/AnalysisRunnerView/AnalysisGrid.css' +import { TableVirtuoso } from 'react-virtuoso' + +import Table from '@mui/material/Table' +import TableBody from '@mui/material/TableBody' +import TableContainer from '@mui/material/TableContainer' +import TableHead from '@mui/material/TableHead' +import TableRow from '@mui/material/TableRow' +import Paper from '@mui/material/Paper' +import formatMoney from '../../../shared/utilities/formatMoney' + +const hailBatchUrl = 'https://batch.hail.populationgenomics.org.au/batches' + +const BatchGrid: React.FunctionComponent<{ + data: any +}> = ({ data }) => { + const [openRows, setOpenRows] = React.useState([]) + + const handleToggle = (position: string) => { + if (!openRows.includes(position)) { + setOpenRows([...openRows, position]) + } else { + setOpenRows(openRows.filter((value) => value !== position)) + } + } + + const prepareBatchUrl = (batch_id: string) => ( + + BATCH ID: {batch_id} + + ) + + const prepareBgColor = (log: any) => { + if (log.batch_id === undefined) { + return 'var(--color-border-color)' + } + if (log.job_id === undefined) { + return 'var(--color-border-default)' + } + return 'var(--color-bg)' + } + + const calcDuration = (dataItem) => { + const duration = new Date(dataItem.usage_end_time) - new Date(dataItem.usage_start_time) + const seconds = Math.floor((duration / 1000) % 60) + const minutes = Math.floor((duration / (1000 * 60)) % 60) + const hours = Math.floor((duration / (1000 * 60 * 60)) % 24) + const formattedDuration = `${hours}h ${minutes}m ${seconds}s` + return {formattedDuration} + } + + const idx = 0 + + const displayCheckBoxRow = ( + parentToggle: string, + key: string, + toggle: string, + text: string + ) => ( + + + + handleToggle(toggle)} + /> + + {text} + + ) + + const displayTopLevelCheckBoxRow = (key: string, text: string) => ( + + + handleToggle(key)} + /> + + {text} + + ) + + const displayRow = (toggle: string, key: string, label: string, text: string) => ( + + + + {label} + + {text} + + ) + + const displayCostBySkuRow = ( + parentToggles: list, + toggle: string, + chartId: string, + chartMaxWidth: number, + colSpan: number, + data: any + ) => ( + <> + openRows.includes(p)) && + openRows.includes(toggle) + ? 
'table-row' + : 'none', + backgroundColor: 'var(--color-bg)', + }} + key={toggle} + > + + + + {chartId && ( + ({ + label: srec.sku, + value: srec.cost, + }))} + maxSlices={data.skus.length} + showLegend={false} + isLoading={false} + maxWidth={chartMaxWidth} + /> + )} + + + + SKU + COST + + + + {data.skus.map((srec, sidx) => ( + + {srec.sku} + {formatMoney(srec.cost, 4)} + + ))} + + + + + + ) + + const displayCostBySeqGrpRow = ( + parentToggle: string, + key: string, + toggle: string, + textCheckbox: string, + data: any + ) => ( + <> + {displayCheckBoxRow(parentToggle, key, toggle, textCheckbox)} + + + + + + + + SEQ GROUP + STAGE + COST + + + + {data.seq_groups + .sort((a, b) => b.cost - a.cost) // Sort by cost in descending order + .map((gcat, gidx) => ( + + {gcat.sequencing_group} + {gcat.stage} + {formatMoney(gcat.cost, 4)} + + ))} + + + + + + ) + + const displayCommonSection = (key: string, header: string, data: any) => ( + <> + {displayTopLevelCheckBoxRow(`row-${key}`, `${header}`)} + + {displayRow( + '', + `${key}-detail-cost`, + 'Cost', + `${formatMoney(data.cost, 4)} ${ + data.jobs_cnt > 0 ? ` (across ${data.jobs_cnt} jobs)` : '' + }` + )} + + {displayRow(`row-${key}`, `${key}-detail-start`, 'Start', data.usage_start_time)} + {displayRow(`row-${key}`, `${key}-detail-end`, 'End', data.usage_end_time)} + + {displayCheckBoxRow(`row-${key}`, `sku-toggle-${key}`, `sku-${key}`, 'Cost By SKU')} + {displayCostBySkuRow([`row-${key}`], `sku-${key}`, `donut-chart-${key}`, 600, 1, data)} + + ) + + const ExpandableRow = ({ item, ...props }) => { + const index = props['data-index'] + return ( + + + + handleToggle(`${item.batch_id}-${item.job_id}`)} + /> + + {item.job_id} + {item.job_name} + {item.usage_start_time} + {calcDuration(item)} + {formatMoney(item.cost, 4)} + + + {/* cost by SKU */} + {displayCostBySkuRow( + [`row-${item.batch_id}`, `jobs-${item.batch_id}`], + `${item.batch_id}-${item.job_id}`, + undefined, + undefined, + 4, + item + )} + + ) + } + + const TableComponents = { + Scroller: React.forwardRef((props, ref) => ( + + )), + Table: (props) => , + TableHead: TableHead, + TableRow: ExpandableRow, + TableBody: React.forwardRef((props, ref) => ), + } + + const displayJobsTable = (item) => ( + 1 ? 800 : 400, backgroundColor: 'var(--color-bg)' }} + className="ui celled table compact" + useWindowScroll={false} + data={item.jobs.sort((a, b) => { + // Sorts an array of objects first by 'job_id' in ascending order. + if (a.job_id < b.job_id) { + return -1 + } + if (a.job_id > b.job_id) { + return 1 + } + return 0 + })} + fixedHeaderContent={() => ( + + + JOB ID + NAME + START + DURATION + COST + + )} + components={TableComponents} + /> + ) + + const arGuidCard = (idx, data) => ( + + + + <> + {displayTopLevelCheckBoxRow(`row-${idx}`, `AR-GUID: ${data.total.ar_guid}`)} + + {displayRow( + '', + `${idx}-detail-cost`, + 'Total cost', + formatMoney(data.total.cost, 2) + )} + + {/* cost by categories */} + {data.categories.map((tcat, cidx) => { + const workflows = + tcat.workflows !== null + ? 
` (across ${tcat.workflows} workflows)` + : '' + return displayRow( + '', + `categories-${idx}-${cidx}`, + tcat.category, + `${formatMoney(tcat.cost, 2)} ${workflows}` + ) + })} + + {displayRow( + '', + `${idx}-detail-start`, + 'Start', + data.total.usage_start_time + )} + {displayRow('', `${idx}-detail-end`, 'End', data.total.usage_end_time)} + + {/* all meta if present */} + {data.analysisRunnerLog && + Object.keys(data.analysisRunnerLog.meta).map((key) => { + const mcat = data.analysisRunnerLog.meta[key] + return displayRow(`row-${idx}`, `${idx}-meta-${key}`, key, mcat) + })} + + {/* cost by topics */} + {displayCheckBoxRow( + `row-${idx}`, + `topics-toggle-${idx}`, + `topics-${idx}`, + 'Cost By Topic' + )} + + + + + + + + Topic + Cost + + + + {data.topics.map((trec, tidx) => ( + + {trec.topic} + + {formatMoney(trec.cost, 2)} + + + ))} + + + + + + {/* cost by seq groups */} + {displayCostBySeqGrpRow( + `row-${idx}`, + `seq-grp-toggle-${idx}`, + `seq-grp-${idx}`, + 'Cost By Sequencing Group', + data + )} + + {/* cost by SKU */} + {displayCheckBoxRow( + `row-${idx}`, + `sku-toggle-${idx}`, + `sku-${idx}`, + 'Cost By SKU' + )} + {displayCostBySkuRow( + [`row-${idx}`], + `sku-${idx}`, + 'total-donut-chart', + 600, + 1, + data + )} + + + + + ) + + const batchCard = (item) => ( + + + + {displayTopLevelCheckBoxRow( + `row-${item.batch_id}`, + prepareBatchUrl(item.batch_id) + )} + + {displayRow('', `${item.batch_id}-detail-name`, 'Batch Name', item.batch_name)} + + {item.jobs_cnt === 1 + ? displayRow( + '', + `${item.batch_id}-detail-job-name`, + 'Job Name', + item.jobs[0].job_name + ) + : null} + + {displayRow( + '', + `${item.batch_id}-detail-cost`, + 'Cost', + `${formatMoney(item.cost, 4)} ${ + item.jobs_cnt !== null ? ` (across ${item.jobs_cnt} jobs)` : '' + }` + )} + + {displayRow( + `row-${item.batch_id}`, + `${item.batch_id}-detail-start`, + 'Start', + data.total.usage_start_time + )} + {displayRow( + `row-${item.batch_id}`, + `${item.batch_id}-detail-end`, + 'End', + data.total.usage_end_time + )} + + {/* cost by seq groups */} + {displayCostBySeqGrpRow( + `row-${item.batch_id}`, + `seq-grp-toggle-${item.batch_id}`, + `seq-grp-${item.batch_id}`, + 'Cost By Sequencing Group', + item + )} + + {/* cost by SKU */} + {displayCheckBoxRow( + `row-${item.batch_id}`, + `sku-toggle-${item.batch_id}`, + `sku-${item.batch_id}`, + 'Cost By SKU' + )} + {displayCostBySkuRow( + [`row-${item.batch_id}`], + `sku-${item.batch_id}`, + `donut-chart-${item.batch_id}`, + 600, + 1, + item + )} + + {/* cost by jobs */} + {item.jobs_cnt > 1 && ( + <> + {displayCheckBoxRow( + `row-${item.batch_id}`, + `jobs-toggle-${item.batch_id}`, + `jobs-${item.batch_id}`, + 'Cost By JOBS' + )} + + + + {displayJobsTable(item)} + + + )} + + + + ) + + const genericCard = (item, data, label) => ( + + + {displayCommonSection(data, label, item)} + + + ) + + return ( + <> + {arGuidCard(idx, data)} + + {data.batches.map((item) => batchCard(item))} + + {data.dataproc.map((item) => genericCard(item, item.dataproc, `DATAPROC`))} + + {data.wdl_tasks.map((item) => + genericCard(item, item.wdl_task_name, `WDL TASK NAME: ${item.wdl_task_name}`) + )} + + {data.cromwell_sub_workflows.map((item) => + genericCard( + item, + item.cromwell_sub_workflow_name, + `CROMWELL SUB WORKFLOW NAME: ${item.cromwell_sub_workflow_name}` + ) + )} + + {data.cromwell_workflows.map((item) => + genericCard( + item, + item.cromwell_workflow_id, + `CROMWELL WORKFLOW ID: ${item.cromwell_workflow_id}` + ) + )} + + ) +} + +export default BatchGrid diff --git 
a/web/src/pages/billing/components/HailBatchGrid.tsx b/web/src/pages/billing/components/HailBatchGrid.tsx deleted file mode 100644 index d3e8e199c..000000000 --- a/web/src/pages/billing/components/HailBatchGrid.tsx +++ /dev/null @@ -1,499 +0,0 @@ -import * as React from 'react' -import { Table as SUITable, Popup, Checkbox } from 'semantic-ui-react' -import _ from 'lodash' -import Table from '../../../shared/components/Table' -import sanitiseValue from '../../../shared/utilities/sanitiseValue' -import '../../project/AnalysisRunnerView/AnalysisGrid.css' - -interface Field { - category: string - title: string - width?: string - className?: string - dataMap?: (data: any, value: string) => any -} - -const HailBatchGrid: React.FunctionComponent<{ - data: any[] -}> = ({ data }) => { - // prepare aggregated data by ar_guid, batch_id, job_id and coresponding batch_resource - const aggArGUIDData: any[] = [] - data.forEach((curr) => { - const { cost, topic, usage_start_time, usage_end_time } = curr - const ar_guid = curr['ar-guid'] - const usageStartDate = new Date(usage_start_time) - const usageEndDate = new Date(usage_end_time) - const idx = aggArGUIDData.findIndex((d) => d.ar_guid === ar_guid && d.topic === topic) - if (cost >= 0) { - // do not include credits, should be filter out at API? - if (idx === -1) { - aggArGUIDData.push({ - type: 'ar_guid', - key: ar_guid, - ar_guid, - batch_id: undefined, - job_id: undefined, - topic, - cost, - start_time: usageStartDate, - end_time: usageEndDate, - }) - } else { - aggArGUIDData[idx].cost += cost - aggArGUIDData[idx].start_time = new Date( - Math.min(usageStartDate.getTime(), aggArGUIDData[idx].start_time.getTime()) - ) - aggArGUIDData[idx].end_time = new Date( - Math.max(usageEndDate.getTime(), aggArGUIDData[idx].end_time.getTime()) - ) - } - } - }) - - const aggArGUIDResource: any[] = [] - data.forEach((curr) => { - const { cost, batch_resource } = curr - const ar_guid = curr['ar-guid'] - const idx = aggArGUIDResource.findIndex( - (d) => d.ar_guid === ar_guid && d.batch_resource === batch_resource - ) - if (cost >= 0) { - // do not include credits, should be filter out at API? - if (idx === -1) { - aggArGUIDResource.push({ - type: 'ar_guid', - key: ar_guid, - ar_guid, - batch_resource, - cost, - }) - } else { - aggArGUIDResource[idx].cost += cost - } - } - }) - const aggBatchData: any[] = [] - data.forEach((curr) => { - const { - batch_id, - url, - topic, - namespace, - batch_name, - cost, - usage_start_time, - usage_end_time, - } = curr - const ar_guid = curr['ar-guid'] - const usageStartDate = new Date(usage_start_time) - const usageEndDate = new Date(usage_end_time) - const idx = aggBatchData.findIndex( - (d) => - d.batch_id === batch_id && - d.batch_name === batch_name && - d.topic === topic && - d.namespace === namespace - ) - if (cost >= 0) { - // do not include credits, should be filter out at API? 
- if (idx === -1) { - aggBatchData.push({ - type: 'batch_id', - key: batch_id, - ar_guid, - batch_id, - url, - topic, - namespace, - batch_name, - job_id: undefined, - cost, - start_time: usageStartDate, - end_time: usageEndDate, - }) - } else { - aggBatchData[idx].cost += cost - aggBatchData[idx].start_time = new Date( - Math.min(usageStartDate.getTime(), aggBatchData[idx].start_time.getTime()) - ) - aggBatchData[idx].end_time = new Date( - Math.max(usageEndDate.getTime(), aggBatchData[idx].end_time.getTime()) - ) - } - } - }) - - const aggBatchResource: any[] = [] - data.forEach((curr) => { - const { batch_id, batch_resource, topic, namespace, batch_name, cost } = curr - const ar_guid = curr['ar-guid'] - const idx = aggBatchResource.findIndex( - (d) => - d.batch_id === batch_id && - d.batch_name === batch_name && - d.batch_resource === batch_resource && - d.topic === topic && - d.namespace === namespace - ) - if (cost >= 0) { - // do not include credits, should be filter out at API? - if (idx === -1) { - aggBatchResource.push({ - type: 'batch_id', - key: batch_id, - ar_guid, - batch_id, - batch_resource, - topic, - namespace, - batch_name, - cost, - }) - } else { - aggBatchResource[idx].cost += cost - } - } - }) - - const aggBatchJobData: any[] = [] - data.forEach((curr) => { - const { batch_id, url, cost, topic, namespace, job_id, usage_start_time, usage_end_time } = - curr - const ar_guid = curr['ar-guid'] - const usageStartDate = new Date(usage_start_time) - const usageEndDate = new Date(usage_end_time) - const idx = aggBatchJobData.findIndex( - (d) => - d.batch_id === batch_id && - d.job_id === job_id && - d.topic === topic && - d.namespace === namespace - ) - if (cost >= 0) { - if (idx === -1) { - aggBatchJobData.push({ - type: 'batch_id/job_id', - key: `${batch_id}/${job_id}`, - batch_id, - job_id, - ar_guid, - url, - topic, - namespace, - cost, - start_time: usageStartDate, - end_time: usageEndDate, - }) - } else { - aggBatchJobData[idx].cost += cost - aggBatchJobData[idx].start_time = new Date( - Math.min(usageStartDate.getTime(), aggBatchJobData[idx].start_time.getTime()) - ) - aggBatchJobData[idx].end_time = new Date( - Math.max(usageEndDate.getTime(), aggBatchJobData[idx].end_time.getTime()) - ) - } - } - }) - - const aggBatchJobResource: any[] = [] - data.forEach((curr) => { - const { batch_id, batch_resource, topic, namespace, cost, job_id, job_name } = curr - const ar_guid = curr['ar-guid'] - const idx = aggBatchJobResource.findIndex( - (d) => - d.batch_id === batch_id && - d.job_id === job_id && - d.batch_resource === batch_resource && - d.topic === topic && - d.namespace === namespace - ) - if (cost >= 0) { - if (idx === -1) { - aggBatchJobResource.push({ - type: 'batch_id/job_id', - key: `${batch_id}/${job_id}`, - batch_id, - job_id, - ar_guid, - batch_resource, - topic, - namespace, - cost, - job_name, - }) - } else { - aggBatchJobResource[idx].cost += cost - } - } - }) - - const aggData = [...aggArGUIDData, ...aggBatchData, ...aggBatchJobData] - const aggResource = [...aggArGUIDResource, ...aggBatchResource, ...aggBatchJobResource] - - // combine data and resource for each ar_guid, batch_id, job_id - const combinedData = aggData.map((dataItem) => { - const details = aggResource.filter( - (resourceItem) => - resourceItem.key === dataItem.key && resourceItem.type === dataItem.type - ) - return { ...dataItem, details } - }) - - const [openRows, setOpenRows] = React.useState([]) - - const handleToggle = (position: number) => { - if (!openRows.includes(position)) { - 
setOpenRows([...openRows, position]) - } else { - setOpenRows(openRows.filter((i) => i !== position)) - } - } - - const prepareBatchUrl = (url: string, txt: string) => ( - - {txt} - - ) - - const prepareBgColor = (log: any) => { - if (log.batch_id === undefined) { - return 'var(--color-border-color)' - } - if (log.job_id === undefined) { - return 'var(--color-border-default)' - } - return 'var(--color-bg)' - } - - const MAIN_FIELDS: Field[] = [ - { - category: 'job_id', - title: 'ID', - dataMap: (dataItem: any, value: string) => { - if (dataItem.batch_id === undefined) { - return `AR GUID: ${dataItem.ar_guid}` - } - if (dataItem.job_id === undefined) { - return prepareBatchUrl(dataItem.url, `BATCH ID: ${dataItem.batch_id}`) - } - return prepareBatchUrl(dataItem.url, `JOB: ${value}`) - }, - }, - { - category: 'start_time', - title: 'TIME STARTED', - dataMap: (dataItem: any, value: string) => { - const dateValue = new Date(value) - return ( - - {Number.isNaN(dateValue.getTime()) ? '' : dateValue.toLocaleString()} - - ) - }, - }, - { - category: 'end_time', - title: 'TIME COMPLETED', - dataMap: (dataItem: any, value: string) => { - const dateValue = new Date(value) - return ( - - {Number.isNaN(dateValue.getTime()) ? '' : dateValue.toLocaleString()} - - ) - }, - }, - { - category: 'duration', - title: 'DURATION', - dataMap: (dataItem: any, _value: string) => { - const duration = new Date( - dataItem.end_time.getTime() - dataItem.start_time.getTime() - ) - const seconds = Math.floor((duration / 1000) % 60) - const minutes = Math.floor((duration / (1000 * 60)) % 60) - const hours = Math.floor((duration / (1000 * 60 * 60)) % 24) - const formattedDuration = `${hours}h ${minutes}m ${seconds}s` - return {formattedDuration} - }, - }, - { - category: 'cost', - title: 'COST', - dataMap: (dataItem: any, _value: string) => ( - ${dataItem.cost.toFixed(4)}} - position="top center" - /> - ), - }, - ] - - const DETAIL_FIELDS: Field[] = [ - { - category: 'topic', - title: 'TOPIC', - }, - { - category: 'namespace', - title: 'NAMESPACE', - }, - { - category: 'batch_name', - title: 'NAME/SCRIPT', - }, - { - category: 'job_name', - title: 'NAME', - }, - ] - - const expandedRow = (log: any, idx: any) => - MAIN_FIELDS.map(({ category, dataMap, className }) => ( - - {dataMap ? dataMap(log, log[category]) : sanitiseValue(log[category])} - - )) - - return ( -
- - - - {MAIN_FIELDS.map(({ category, title }, i) => ( - - {title} - - ))} - - - - {MAIN_FIELDS.map(({ category }, i) => ( - - ))} - - - - {combinedData - .sort((a, b) => { - // Sorts an array of objects first by 'batch_id' and then by 'job_id' in ascending order. - if (a.batch_id < b.batch_id) { - return -1 - } - if (a.batch_id > b.batch_id) { - return 1 - } - if (a.job_id < b.job_id) { - return -1 - } - if (a.job_id > b.job_id) { - return 1 - } - return 0 - }) - .map((log, idx) => ( - - - - handleToggle(log.key)} - /> - - {expandedRow(log, idx)} - - {Object.entries(log) - .filter(([c]) => - DETAIL_FIELDS.map(({ category }) => category).includes(c) - ) - .map(([k, v]) => { - const detailField = DETAIL_FIELDS.find( - ({ category }) => category === k - ) - const title = detailField ? detailField.title : k - return ( - - - - {title} - - {v} - - ) - })} - - - - COST BREAKDOWN - - - {typeof log === 'object' && - 'details' in log && - _.orderBy(log?.details, ['cost'], ['desc']).map((dk) => ( - - - - {dk.batch_resource} - - ${dk.cost.toFixed(4)} - - ))} - - ))} - -
- ) -} - -export default HailBatchGrid diff --git a/web/src/shared/components/Graphs/DonutChart.tsx b/web/src/shared/components/Graphs/DonutChart.tsx index 02f46292b..af1f99ba7 100644 --- a/web/src/shared/components/Graphs/DonutChart.tsx +++ b/web/src/shared/components/Graphs/DonutChart.tsx @@ -9,10 +9,14 @@ export interface IDonutChartData { } export interface IDonutChartProps { + id?: string data?: IDonutChartData[] maxSlices: number - colors?: (t: number) => string | undefined + colors?: (t: number) => string isLoading: boolean + legendSize?: number + showLegend?: boolean + maxWidth?: string } interface IDonutChartPreparadData { @@ -29,7 +33,16 @@ function calcTranslate(data: IDonutChartPreparadData, move = 4) { })` } -export const DonutChart: React.FC = ({ data, maxSlices, colors, isLoading }) => { +export const DonutChart: React.FC = ({ + id, + data, + maxSlices, + colors, + isLoading, + legendSize, + showLegend, + maxWidth, +}) => { if (isLoading) { return (
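
For reference, a minimal sketch of how the extended DonutChart is expected to be used from BatchGrid.tsx: the prop names come from the hunks above and the import path from BatchGrid's own imports, while the SkuCost shape and the "600px" width are assumptions (the component's JSX bodies are elided in this flattened patch).

import * as React from 'react'
import { DonutChart } from '../../../shared/components/Graphs/DonutChart'

// Assumed shape of a per-SKU cost entry, matching the `sku`/`cost` fields read by BatchGrid.
interface SkuCost {
    sku: string
    cost: number
}

// Passing a unique `id` namespaces the D3 element ids (path/lbl/legend/lgd), so several
// donut charts can coexist on one page without their hover handlers highlighting
// each other's slices.
const SkuDonut: React.FC<{ batchId: string; skus: SkuCost[] }> = ({ batchId, skus }) => (
    <DonutChart
        id={`donut-chart-${batchId}`}
        data={skus.map((srec) => ({ label: srec.sku, value: srec.cost }))}
        maxSlices={skus.length}
        showLegend={false}
        isLoading={false}
        maxWidth="600px" // IDonutChartProps declares maxWidth?: string; BatchGrid passes a number, so this is an assumption
    />
)
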
@@ -46,11 +59,14 @@ export const DonutChart: React.FC = ({ data, maxSlices, colors const duration = 250 const containerDivRef = React.useRef() const [graphWidth, setGraphWidth] = React.useState(768) + // to distinquished between charts on the same page we need an id + const chartId = id ?? 'donutChart' const onHoverOver = (tg: HTMLElement, v: IDonutChartPreparadData) => { - select(`#lbl${v.index}`).select('tspan').attr('font-weight', 'bold') - select(`#legend${v.index}`).attr('font-weight', 'bold') - select(`#lgd${v.index}`).attr('font-weight', 'bold') + select(`#${chartId}-lbl${v.index}`).select('tspan').attr('font-weight', 'bold') + select(`#${chartId}-legend${v.index}`).attr('font-weight', 'bold') + select(`#${chartId}-lgd${v.index}`).attr('font-weight', 'bold') + select(`#${chartId}-lgd${v.index}`).attr('style', 'font-weight: bold') select(tg).transition().duration(duration).attr('transform', calcTranslate(v, 6)) select(tg) .select('path') @@ -62,9 +78,10 @@ export const DonutChart: React.FC = ({ data, maxSlices, colors } const onHoverOut = (tg: HTMLElement, v: IDonutChartPreparadData) => { - select(`#lbl${v.index}`).select('tspan').attr('font-weight', 'normal') - select(`#legend${v.index}`).attr('font-weight', 'normal') - select(`#lgd${v.index}`).attr('font-weight', 'normal') + select(`#${chartId}-lbl${v.index}`).select('tspan').attr('font-weight', 'normal') + select(`#${chartId}-legend${v.index}`).attr('font-weight', 'normal') + select(`#${chartId}-lgd${v.index}`).attr('font-weight', 'normal') + select(`#${chartId}-lgd${v.index}`).attr('style', 'font-weight: normal') select(tg).transition().duration(duration).attr('transform', 'translate(0, 0)') select(tg) .select('path') @@ -74,6 +91,20 @@ export const DonutChart: React.FC = ({ data, maxSlices, colors .attr('stroke-width', 1) } + function createViewBox(legSize, w) { + // calculate the viewbox of Legend + const minX = 0 + const minY = 0 + let width = 200 + let height = 200 + if (legSize) { + width = legSize * w + height = legSize * w + } + + return `${minX} ${minY} ${width} ${height}` + } + const width = graphWidth const height = width const margin = 15 @@ -138,7 +169,7 @@ export const DonutChart: React.FC = ({ data, maxSlices, colors .style('stroke-width', '2') .style('opacity', '0.8') .style('cursor', 'pointer') - .attr('id', (d) => `path${d.index}`) + .attr('id', (d) => `${chartId}-path${d.index}`) .on('mouseover', (event, v) => { onHoverOver(event.currentTarget, v) }) @@ -159,7 +190,7 @@ export const DonutChart: React.FC = ({ data, maxSlices, colors .data(data_ready) .join('text') .attr('transform', (d) => `translate(${arcLabel.centroid(d)})`) - .attr('id', (d) => `lbl${d.index}`) + .attr('id', (d) => `${chartId}-lbl${d.index}`) .selectAll('tspan') .data((d) => { const lines = `${formatMoney(d.data.value)}`.split(/\n/) @@ -172,42 +203,44 @@ export const DonutChart: React.FC = ({ data, maxSlices, colors .text((d) => d) // add legend - const svgLegend = select(contDiv) - .append('svg') - .attr('width', '45%') - .attr('viewBox', '0 0 200 200') - .attr('vertical-align', 'top') - - svgLegend - .selectAll('g.legend') - .data(data_ready) - .enter() - .append('g') - .attr('transform', (d) => `translate(${margin},${margin + d.index * 20})`) - .each(function (d, i) { - select(this) - .append('circle') - .attr('r', 8) - .attr('fill', (d) => colorFunc(d.index / maxSlices)) - select(this) - .append('text') - .attr('text-anchor', 'start') - .attr('x', 20) - .attr('y', 0) - .attr('dy', '0.35em') - .attr('id', (d) => `legend${d.index}`) - 
.text(d.data.label) - .attr('font-size', '0.9em') - select(this) - .on('mouseover', (event, v) => { - const element = select(`#path${d.index}`) - onHoverOver(element.node(), d) - }) - .on('mouseout', (event, v) => { - const element = select(`#path${d.index}`) - onHoverOut(element.node(), d) - }) - }) + if (showLegend === true) { + const svgLegend = select(contDiv) + .append('svg') + .attr('width', '45%') + .attr('viewBox', createViewBox(legendSize, width)) + .attr('vertical-align', 'top') + + svgLegend + .selectAll('g.legend') + .data(data_ready) + .enter() + .append('g') + .attr('transform', (d) => `translate(${margin},${margin + d.index * 20})`) + .each(function (d, i) { + select(this) + .append('circle') + .attr('r', 8) + .attr('fill', (d) => colorFunc(d.index / maxSlices)) + select(this) + .append('text') + .attr('text-anchor', 'start') + .attr('x', 20) + .attr('y', 0) + .attr('dy', '0.35em') + .attr('id', (d) => `${chartId}-legend${d.index}`) + .text(d.data.label) + .attr('font-size', '0.9em') + select(this) + .on('mouseover', (event, v) => { + const element = select(`#${chartId}-path${d.index}`) + onHoverOver(element.node(), d) + }) + .on('mouseout', (event, v) => { + const element = select(`#${chartId}-path${d.index}`) + onHoverOut(element.node(), d) + }) + }) + } } - return
+ return
}
diff --git a/web/src/shared/utilities/formatMoney.ts b/web/src/shared/utilities/formatMoney.ts
index 3a270b7fc..d6039aba2 100644
--- a/web/src/shared/utilities/formatMoney.ts
+++ b/web/src/shared/utilities/formatMoney.ts
@@ -1,3 +1,3 @@
-const formatMoney = (val: number): string => `$${val.toFixed(2).replace(/\d(?=(\d{3})+\.)/g, '$&,')}`
+const formatMoney = (val: number, dp: number = 2): string => `$${val.toFixed(dp).replace(/\d(?=(\d{3})+\.)/g, '$&,')}`
 
 export default formatMoney
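
BatchGrid.tsx above types its `data` prop as `any`. Pieced together from the fields the component reads, the cost-by-AR-GUID payload looks roughly like the sketch below; the field names are taken from the grid code, while the optionality and primitive types are assumptions (the authoritative shape is the Python-side BillingBatchCostRecord).

// Rough, assumed TypeScript view of the object rendered by BatchGrid, inferred
// from the fields the component reads; not an authoritative model definition.
interface SkuCost { sku: string; cost: number }
interface SeqGroupCost { sequencing_group: string; stage: string; cost: number }

interface CommonCost {
    cost: number
    usage_start_time: string
    usage_end_time: string
    skus: SkuCost[]
}

// Sections rendered via displayCommonSection also report a job count.
interface SectionCost extends CommonCost {
    jobs_cnt: number
}

interface JobCost extends CommonCost {
    job_id: number // assumed numeric; only sorted and displayed by the grid
    job_name?: string
}

interface BatchCost extends SectionCost {
    batch_id: string
    batch_name?: string
    seq_groups: SeqGroupCost[]
    jobs: JobCost[]
}

interface BatchGridData {
    total: { ar_guid: string; cost: number; usage_start_time: string; usage_end_time: string }
    categories: { category: string; cost: number; workflows?: number }[]
    topics: { topic: string; cost: number }[]
    seq_groups: SeqGroupCost[]
    skus: SkuCost[]
    batches: BatchCost[]
    dataproc: (SectionCost & { dataproc: string })[] // `dataproc` used as the row key/label
    wdl_tasks: (SectionCost & { wdl_task_name: string })[]
    cromwell_sub_workflows: (SectionCost & { cromwell_sub_workflow_name: string })[]
    cromwell_workflows: (SectionCost & { cromwell_workflow_id: string })[]
    analysisRunnerLog?: { meta: Record<string, string> } // value types assumed
}
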
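
Finally, the formatMoney change is backwards compatible: the new `dp` argument defaults to 2, so existing callers keep their behaviour, while the grids above pass 4 decimal places for per-SKU and per-job costs. A quick illustration with arbitrary values (import path as used in BatchGrid.tsx):

import formatMoney from '../../../shared/utilities/formatMoney'

formatMoney(1234.5678)    // "$1,234.57"   – default two decimal places, as before
formatMoney(1234.5678, 4) // "$1,234.5678" – as used for SKU- and job-level costs
formatMoney(0.00009, 4)   // "$0.0001"     – toFixed rounds the value
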