diff --git a/api/routes/billing.py b/api/routes/billing.py
index 7d93a599c..3a830197d 100644
--- a/api/routes/billing.py
+++ b/api/routes/billing.py
@@ -1,17 +1,20 @@
"""
Billing routes
"""
+
from async_lru import alru_cache
from fastapi import APIRouter
+from fastapi.encoders import jsonable_encoder
+from fastapi.responses import JSONResponse
from api.settings import BILLING_CACHE_RESPONSE_TTL, BQ_AGGREG_VIEW
from api.utils.db import BqConnection, get_author
from db.python.layers.billing import BillingLayer
from models.enums import BillingSource
from models.models import (
+ BillingBatchCostRecord,
BillingColumn,
BillingCostBudgetRecord,
- BillingHailBatchCostRecord,
BillingTotalCostQueryModel,
BillingTotalCostRecord,
)
@@ -273,34 +276,38 @@ async def get_namespaces(
@router.get(
'/cost-by-ar-guid/{ar_guid}',
- response_model=BillingHailBatchCostRecord,
+ response_model=list[BillingBatchCostRecord],
operation_id='costByArGuid',
)
@alru_cache(maxsize=10, ttl=BILLING_CACHE_RESPONSE_TTL)
async def get_cost_by_ar_guid(
ar_guid: str,
author: str = get_author,
-) -> BillingHailBatchCostRecord:
+) -> JSONResponse:
"""Get Hail Batch costs by AR GUID"""
billing_layer = _get_billing_layer_from(author)
records = await billing_layer.get_cost_by_ar_guid(ar_guid)
- return records
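+ # expose the accumulated BigQuery cost of this request as a response header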
+ headers = {'x-bq-cost': str(billing_layer.connection.cost)}
+ json_compatible_item_data = jsonable_encoder(records)
+ return JSONResponse(content=json_compatible_item_data, headers=headers)
@router.get(
'/cost-by-batch-id/{batch_id}',
- response_model=BillingHailBatchCostRecord,
+ response_model=list[BillingBatchCostRecord],
operation_id='costByBatchId',
)
@alru_cache(maxsize=10, ttl=BILLING_CACHE_RESPONSE_TTL)
async def get_cost_by_batch_id(
batch_id: str,
author: str = get_author,
-) -> BillingHailBatchCostRecord:
+) -> JSONResponse:
"""Get Hail Batch costs by Batch ID"""
billing_layer = _get_billing_layer_from(author)
records = await billing_layer.get_cost_by_batch_id(batch_id)
- return records
+ headers = {'x-bq-cost': str(billing_layer.connection.cost)}
+ json_compatible_item_data = jsonable_encoder(records)
+ return JSONResponse(content=json_compatible_item_data, headers=headers)
@router.post(
diff --git a/api/settings.py b/api/settings.py
index 98d11062f..f1d4c1c77 100644
--- a/api/settings.py
+++ b/api/settings.py
@@ -44,6 +44,9 @@
BQ_GCP_BILLING_VIEW = os.getenv('SM_GCP_BQ_BILLING_VIEW')
BQ_BATCHES_VIEW = os.getenv('SM_GCP_BQ_BATCHES_VIEW')
+# BQ cost per 1 TiB of data processed, used to calculate the cost of BQ queries
+BQ_COST_PER_TB = 6.25
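+# e.g. a query that processes 0.5 TiB adds 0.5 * 6.25 = 3.125 USD to the request cost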
+
# This is to optimise BQ queries, DEV table has data only for Mar 2023
BQ_DAYS_BACK_OPTIMAL = 30 # Look back 30 days for optimal query
BILLING_CACHE_RESPONSE_TTL = 3600 # 1 Hour
diff --git a/db/python/gcp_connect.py b/db/python/gcp_connect.py
index 72fbc97df..5d0271e68 100644
--- a/db/python/gcp_connect.py
+++ b/db/python/gcp_connect.py
@@ -1,6 +1,7 @@
"""
Code for connecting to Big Query database
"""
+
import logging
import os
@@ -23,6 +24,8 @@ def __init__(
self.gcp_project = os.getenv('METAMIST_GCP_PROJECT')
self.connection: bq.Client = bq.Client(project=self.gcp_project)
self.author: str = author
+ # initialise the accumulated cost of BQ queries run on this connection
+ self._cost: float = 0
@staticmethod
async def get_connection_no_project(author: str):
@@ -35,6 +38,16 @@ async def get_connection_no_project(author: str):
return BqConnection(author=author)
+ @property
+ def cost(self) -> float:
+ """Get the cost of the query"""
+ return self._cost
+
+ @cost.setter
+ def cost(self, value: float):
+ """Set the cost of the query"""
+ self._cost = value
+
class BqDbBase:
"""Base class for big query database subclasses"""
diff --git a/db/python/layers/billing.py b/db/python/layers/billing.py
index 1122259ba..c92692c2c 100644
--- a/db/python/layers/billing.py
+++ b/db/python/layers/billing.py
@@ -4,11 +4,11 @@
from db.python.tables.bq.billing_daily_extended import BillingDailyExtendedTable
from db.python.tables.bq.billing_gcp_daily import BillingGcpDailyTable
from db.python.tables.bq.billing_raw import BillingRawTable
-from models.enums import BillingSource, BillingTimeColumn, BillingTimePeriods
+from models.enums import BillingSource
from models.models import (
+ BillingBatchCostRecord,
BillingColumn,
BillingCostBudgetRecord,
- BillingHailBatchCostRecord,
BillingTotalCostQueryModel,
)
@@ -32,6 +32,8 @@ def table_factory(
return BillingGcpDailyTable(self.connection)
if source == BillingSource.RAW:
return BillingRawTable(self.connection)
+ if source == BillingSource.EXTENDED:
+ return BillingDailyExtendedTable(self.connection)
# check if any of the fields is in the extended columns
if fields:
@@ -202,7 +204,7 @@ async def get_running_cost(
async def get_cost_by_ar_guid(
self,
ar_guid: str | None = None,
- ) -> BillingHailBatchCostRecord:
+ ) -> list[BillingBatchCostRecord]:
"""
Get Costs by AR GUID
"""
@@ -216,44 +218,18 @@ async def get_cost_by_ar_guid(
) = await ar_batch_lookup_table.get_batches_by_ar_guid(ar_guid)
if not batches:
- return BillingHailBatchCostRecord(
- ar_guid=ar_guid,
- batch_ids=[],
- costs=[],
- )
+ return []
- # Then get the costs for the given AR GUID/batches from the main table
- all_cols = [BillingColumn.str_to_enum(v) for v in BillingColumn.raw_cols()]
-
- query = BillingTotalCostQueryModel(
- fields=all_cols,
- source=BillingSource.RAW,
- start_date=start_day.strftime('%Y-%m-%d'),
- end_date=end_day.strftime('%Y-%m-%d'),
- filters={
- BillingColumn.LABELS: {
- 'batch_id': batches,
- 'ar-guid': ar_guid,
- }
- },
- filters_op='OR',
- group_by=False,
- time_column=BillingTimeColumn.USAGE_END_TIME,
- time_periods=BillingTimePeriods.DAY,
- )
-
- billing_table = self.table_factory(query.source, query.fields)
- records = await billing_table.get_total_cost(query)
- return BillingHailBatchCostRecord(
- ar_guid=ar_guid,
- batch_ids=batches,
- costs=records,
+ billing_table = BillingDailyExtendedTable(self.connection)
+ results = await billing_table.get_batch_cost_summary(
+ start_day, end_day, batches, ar_guid
)
+ return results
async def get_cost_by_batch_id(
self,
batch_id: str | None = None,
- ) -> BillingHailBatchCostRecord:
+ ) -> list[BillingBatchCostRecord]:
"""
Get Costs by Batch ID
"""
@@ -270,31 +246,10 @@ async def get_cost_by_batch_id(
) = await ar_batch_lookup_table.get_batches_by_ar_guid(ar_guid)
if not batches:
- return BillingHailBatchCostRecord(ar_guid=ar_guid, batch_ids=[], costs=[])
+ return []
- # Then get the costs for the given AR GUID/batches from the main table
- all_cols = [BillingColumn.str_to_enum(v) for v in BillingColumn.raw_cols()]
-
- query = BillingTotalCostQueryModel(
- fields=all_cols,
- source=BillingSource.RAW,
- start_date=start_day.strftime('%Y-%m-%d'),
- end_date=end_day.strftime('%Y-%m-%d'),
- filters={
- BillingColumn.LABELS: {
- 'batch_id': batches,
- 'ar-guid': ar_guid,
- }
- },
- filters_op='OR',
- group_by=False,
- time_column=BillingTimeColumn.USAGE_END_TIME,
- time_periods=BillingTimePeriods.DAY,
- )
- billing_table = self.table_factory(query.source, query.fields)
- records = await billing_table.get_total_cost(query)
- return BillingHailBatchCostRecord(
- ar_guid=ar_guid,
- batch_ids=batches,
- costs=records,
+ billing_table = BillingDailyExtendedTable(self.connection)
+ results = await billing_table.get_batch_cost_summary(
+ start_day, end_day, batches, ar_guid
)
+ return results
diff --git a/db/python/layers/bq_base.py b/db/python/layers/bq_base.py
index f1993060a..8897dacd2 100644
--- a/db/python/layers/bq_base.py
+++ b/db/python/layers/bq_base.py
@@ -11,3 +11,8 @@ def __init__(self, connection: BqConnection):
def author(self):
"""Get author from connection"""
return self.connection.author
+
+ @property
+ def cost(self):
+ """Get author from connection"""
+ return self.connection.cost
diff --git a/db/python/tables/assay.py b/db/python/tables/assay.py
index 95837e0f3..1b5645d7a 100644
--- a/db/python/tables/assay.py
+++ b/db/python/tables/assay.py
@@ -197,7 +197,8 @@ async def get_assay_type_numbers_by_batch_for_project(self, project: ProjectId):
"""
rows = await self.connection.fetch_all(_query, {'project': project})
batch_result: dict[str, dict[str, str]] = defaultdict(dict)
- for batch, seqType, count in rows:
+ for row in rows:
+ batch, seqType, count = row['batch'], row['type'], row['n']
batch = str(batch).strip('\"') if batch != 'null' else 'no-batch'
batch_result[batch][seqType] = str(count)
if len(batch_result) == 1 and 'no-batch' in batch_result:
diff --git a/db/python/tables/bq/billing_base.py b/db/python/tables/bq/billing_base.py
index 350dfecab..bdbdb14b1 100644
--- a/db/python/tables/bq/billing_base.py
+++ b/db/python/tables/bq/billing_base.py
@@ -6,7 +6,7 @@
from google.cloud import bigquery
-from api.settings import BQ_BUDGET_VIEW, BQ_DAYS_BACK_OPTIMAL
+from api.settings import BQ_BUDGET_VIEW, BQ_COST_PER_TB, BQ_DAYS_BACK_OPTIMAL
from api.utils.dates import get_invoice_month_range, reformat_datetime
from db.python.gcp_connect import BqDbBase
from db.python.tables.bq.billing_filter import BillingFilter
@@ -99,8 +99,10 @@ def get_table_name(self):
raise NotImplementedError('Calling Abstract method directly')
def _execute_query(
- self, query: str, params: list[Any] = None, results_as_list: bool = True
- ) -> list[Any]:
+ self, query: str, params: list[Any] | None = None, results_as_list: bool = True
+ ) -> (
+ list[Any] | bigquery.table.RowIterator | bigquery.table._EmptyRowIterator | None
+ ):
"""Execute query, add BQ labels"""
if params:
job_config = bigquery.QueryJobConfig(
@@ -109,13 +111,31 @@ def _execute_query(
else:
job_config = bigquery.QueryJobConfig(labels=BQ_LABELS)
+ # We need to dry run the query to calculate the cost;
+ # executing the query does not provide the cost.
+ # more info here:
+ # https://stackoverflow.com/questions/58561153/what-is-the-python-api-i-can-use-to-calculate-the-cost-of-a-bigquery-query/58561358#58561358
+ job_config.dry_run = True
+ job_config.use_query_cache = False
+ query_job = self._connection.connection.query(query, job_config=job_config)
+
+ # This should be thread/async safe as each request
+ # creates a new connection instance,
+ # and queries per request are run in sequential order,
+ # waiting for the previous one to finish
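+ # total_bytes_processed is reported in bytes; convert to TiB (1024**4) before applying the per-TiB rate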
+ self._connection.cost += (
+ query_job.total_bytes_processed / 1024**4
+ ) * BQ_COST_PER_TB
+
+ # now execute the query
+ job_config.dry_run = False
+ job_config.use_query_cache = True
+ query_job = self._connection.connection.query(query, job_config=job_config)
if results_as_list:
- return list(
- self._connection.connection.query(query, job_config=job_config).result()
- )
+ return list(query_job.result())
# otherwise return as BQ iterator
- return self._connection.connection.query(query, job_config=job_config)
+ return query_job
@staticmethod
def _query_to_partitioned_filter(
@@ -129,12 +149,16 @@ def _query_to_partitioned_filter(
# initial partition filter
billing_filter.day = GenericBQFilter[datetime](
- gte=datetime.strptime(query.start_date, '%Y-%m-%d')
- if query.start_date
- else None,
- lte=datetime.strptime(query.end_date, '%Y-%m-%d')
- if query.end_date
- else None,
+ gte=(
+ datetime.strptime(query.start_date, '%Y-%m-%d')
+ if query.start_date
+ else None
+ ),
+ lte=(
+ datetime.strptime(query.end_date, '%Y-%m-%d')
+ if query.end_date
+ else None
+ ),
)
return billing_filter
@@ -362,17 +386,19 @@ async def _append_total_running_cost(
total_monthly=(
total_monthly[COMPUTE]['ALL'] + total_monthly[STORAGE]['ALL']
),
- total_daily=(total_daily[COMPUTE]['ALL'] + total_daily[STORAGE]['ALL'])
- if is_current_month
- else None,
+ total_daily=(
+ (total_daily[COMPUTE]['ALL'] + total_daily[STORAGE]['ALL'])
+ if is_current_month
+ else None
+ ),
compute_monthly=total_monthly[COMPUTE]['ALL'],
- compute_daily=(total_daily[COMPUTE]['ALL'])
- if is_current_month
- else None,
+ compute_daily=(
+ (total_daily[COMPUTE]['ALL']) if is_current_month else None
+ ),
storage_monthly=total_monthly[STORAGE]['ALL'],
- storage_daily=(total_daily[STORAGE]['ALL'])
- if is_current_month
- else None,
+ storage_daily=(
+ (total_daily[STORAGE]['ALL']) if is_current_month else None
+ ),
details=all_details,
budget_spent=None,
budget=None,
@@ -422,17 +448,19 @@ async def _append_running_cost_records(
{
'field': key,
'total_monthly': monthly,
- 'total_daily': (compute_daily + storage_daily)
- if is_current_month
- else None,
+ 'total_daily': (
+ (compute_daily + storage_daily)
+ if is_current_month
+ else None
+ ),
'compute_monthly': compute_monthly,
'compute_daily': compute_daily,
'storage_monthly': storage_monthly,
'storage_daily': storage_daily,
'details': details,
- 'budget_spent': 100 * monthly / budget_monthly
- if budget_monthly
- else None,
+ 'budget_spent': (
+ 100 * monthly / budget_monthly if budget_monthly else None
+ ),
'budget': budget_monthly,
'last_loaded_day': last_loaded_day,
}
diff --git a/db/python/tables/bq/billing_daily_extended.py b/db/python/tables/bq/billing_daily_extended.py
index 009144911..5e2d2bf3c 100644
--- a/db/python/tables/bq/billing_daily_extended.py
+++ b/db/python/tables/bq/billing_daily_extended.py
@@ -1,9 +1,13 @@
+from datetime import datetime
+
+from google.cloud import bigquery
+
from api.settings import BQ_AGGREG_EXT_VIEW
from db.python.tables.bq.billing_base import (
BillingBaseTable,
time_optimisation_parameter,
)
-from models.models import BillingColumn
+from models.models import BillingBatchCostRecord, BillingColumn
class BillingDailyExtendedTable(BillingBaseTable):
@@ -49,3 +53,317 @@ async def get_extended_values(self, field: str):
# return empty list if no record found
return []
+
+ async def get_batch_cost_summary(
+ self,
+ start_time: datetime,
+ end_time: datetime,
+ batch_ids: list[str] | None,
+ ar_guid: str | None,
+ ) -> list[BillingBatchCostRecord]:
+ """
+ Get summary of AR run
+ """
+ query_parameters = [
+ bigquery.ScalarQueryParameter('start_time', 'TIMESTAMP', start_time),
+ bigquery.ScalarQueryParameter('end_time', 'TIMESTAMP', end_time),
+ ]
+
+ if ar_guid:
+ condition = 'ar_guid = @ar_guid'
+ query_parameters.append(
+ bigquery.ScalarQueryParameter('ar_guid', 'STRING', ar_guid)
+ )
+ elif batch_ids:
+ condition = 'batch_id in UNNEST(@batch_ids)'
+ query_parameters.append(
+ bigquery.ArrayQueryParameter('batch_ids', 'STRING', batch_ids)
+ )
+ else:
+ raise ValueError('Either ar_guid or batch_ids must be provided')
+
+ # The query locates all records for the given time range
+ # and batch_ids or ar_guid.
+ # All records are fetched in one go to limit the number of queries
+ # and cut the total BQ cost.
+ # Once all the needed data is collected, it is aggregated into:
+ # - cost per topic
+ # - batch counts per category
+ # - cost per hail, cromwell and dataproc
+ # - start and end date of the usage
+ # - total cost, including cost by sku
+
+ _query = f"""
+ -- get all data we need for aggregations in one go
+ WITH d AS (
+ SELECT topic, ar_guid,
+ CASE WHEN compute_category IS NULL AND batch_id IS NOT NULL THEN
+ 'hail batch' ELSE compute_category END AS category,
+ batch_id, CAST(job_id AS INT) job_id,
+ sku, cost,
+ usage_start_time, usage_end_time, sequencing_group, stage, wdl_task_name, cromwell_sub_workflow_name, cromwell_workflow_id,
+ JSON_VALUE(PARSE_JSON(labels), '$.batch_name') as batch_name,
+ JSON_VALUE(PARSE_JSON(labels), '$.job_name') as job_name
+ FROM `{self.table_name}`
+ WHERE day >= @start_time
+ AND day <= @end_time
+ AND {condition}
+ )
+ -- sku per ar-guid record
+ , arsku AS (
+ SELECT sku, SUM(cost) as cost
+ FROM d
+ WHERE cost != 0
+ GROUP BY sku
+ ORDER BY cost DESC
+ )
+ , arskuacc AS (
+ SELECT ARRAY_AGG(STRUCT(sku, cost)) AS skus,
+ FROM arsku
+ )
+ -- sequencing group record
+ ,seqgrp AS (
+ select sequencing_group, stage, sum(cost) as cost from d group by stage, sequencing_group
+ )
+ ,seqgrpacc AS (
+ SELECT ARRAY_AGG(STRUCT(sequencing_group, stage, cost)) as seq_groups
+ FROM seqgrp
+ )
+ -- topics record
+ , t3 AS (
+ SELECT topic, SUM(cost) AS cost
+ FROM d
+ WHERE topic IS NOT NULL
+ GROUP BY topic
+ ORDER BY cost DESC
+ )
+ , t3acc AS (
+ SELECT ARRAY_AGG(STRUCT(topic, cost)) AS topics,
+ FROM t3
+ )
+ -- category specific record
+ , cc AS (
+ SELECT category, sum(cost) AS cost,
+ CASE WHEN category = 'hail batch' THEN COUNT(DISTINCT batch_id) ELSE NULL END as workflows
+ FROM d
+ GROUP BY category
+ )
+ , ccacc AS (
+ SELECT ARRAY_AGG(STRUCT(cc.category, cc.cost, cc.workflows)) AS categories
+ FROM cc
+ )
+ -- total record per ar-guid
+ ,t AS (
+ SELECT STRUCT(
+ ar_guid, sum(d.cost) AS cost,
+ MIN(usage_start_time) AS usage_start_time,
+ max(usage_end_time) AS usage_end_time
+ ) AS total,
+ FROM d
+ GROUP BY ar_guid
+ )
+ -- hail batch job record
+ , jsku AS (
+ SELECT batch_id, job_id, sku, SUM(cost) as cost
+ FROM d
+ WHERE cost != 0
+ AND batch_id IS NOT NULL
+ AND job_id IS NOT NULL
+ GROUP BY batch_id, job_id, sku
+ ORDER BY batch_id, job_id, cost desc
+ )
+ , jskuacc AS (
+ SELECT batch_id, job_id, ARRAY_AGG(STRUCT(sku, cost)) as skus
+ FROM jsku
+ GROUP BY batch_id, job_id
+ )
+ ,j AS (
+ SELECT d.batch_id, d.job_id, d.job_name, sum(d.cost) AS cost,
+ MIN(usage_start_time) AS usage_start_time,
+ max(usage_end_time) AS usage_end_time
+ FROM d
+ WHERE d.batch_id IS NOT NULL AND d.job_id IS NOT NULL
+ GROUP BY d.batch_id, d.job_id, d.job_name
+ ORDER BY job_id
+ )
+ , jacc AS (
+ SELECT j.batch_id, ARRAY_AGG(STRUCT(j.job_id, j.batch_id, job_name, cost, usage_start_time, usage_end_time, jskuacc.skus)) AS jobs
+ from j INNER JOIN jskuacc on jskuacc.batch_id = j.batch_id AND jskuacc.job_id = j.job_id
+ GROUP BY j.batch_id
+ )
+ -- hail batch record
+ , bsku AS (
+ SELECT batch_id, sku, SUM(cost) as cost
+ FROM d
+ WHERE cost != 0
+ AND batch_id IS NOT NULL
+ GROUP BY batch_id, sku
+ ORDER BY batch_id, cost desc
+ )
+ , bskuacc AS (
+ SELECT batch_id, ARRAY_AGG(STRUCT(sku, cost)) as skus
+ FROM bsku
+ GROUP BY batch_id
+ )
+ ,b AS (
+ SELECT d.batch_id, d.batch_name,
+ sum(d.cost) AS cost,
+ MIN(d.usage_start_time) AS usage_start_time,
+ max(d.usage_end_time) AS usage_end_time,
+ COUNT(DISTINCT d.job_id) as jobs_cnt
+ FROM d
+ WHERE d.batch_id IS NOT NULL
+ GROUP BY batch_id, batch_name
+ ORDER BY batch_id
+ )
+ ,bseqgrp AS (
+ select batch_id, sequencing_group, stage, sum(cost) as cost from d group by batch_id, sequencing_group, stage
+ )
+ ,bseqgrpacc AS (
+ SELECT batch_id, ARRAY_AGG(STRUCT(sequencing_group, stage, cost)) as seq_groups
+ FROM bseqgrp
+ GROUP BY batch_id
+ )
+ ,bacc as (
+ SELECT ARRAY_AGG(STRUCT(b.batch_id, batch_name, cost, usage_start_time, usage_end_time, jobs_cnt, bskuacc.skus, jacc.jobs, bseqgrpacc.seq_groups)) AS batches
+ from b
+ INNER JOIN bskuacc on bskuacc.batch_id = b.batch_id
+ INNER JOIN jacc ON jacc.batch_id = b.batch_id
+ INNER JOIN bseqgrpacc ON bseqgrpacc.batch_id = b.batch_id
+ )
+ -- cromwell specific records:
+ -- wdl_task record
+ , wdlsku AS (
+ SELECT wdl_task_name, sku, SUM(cost) as cost
+ FROM d
+ WHERE cost != 0
+ AND wdl_task_name IS NOT NULL
+ GROUP BY wdl_task_name, sku
+ ORDER BY wdl_task_name, cost desc
+ )
+ , wdlskuacc AS (
+ SELECT wdl_task_name, ARRAY_AGG(STRUCT(sku, cost)) as skus
+ FROM wdlsku
+ GROUP BY wdl_task_name
+ )
+ , wdl AS (
+ SELECT d.wdl_task_name,
+ sum(d.cost) AS cost,
+ MIN(d.usage_start_time) AS usage_start_time,
+ max(d.usage_end_time) AS usage_end_time,
+ COUNT(DISTINCT d.job_id) as jobs_cnt
+ FROM d
+ WHERE d.wdl_task_name IS NOT NULL
+ GROUP BY wdl_task_name
+ ORDER BY wdl_task_name
+ )
+ ,wdlacc as (
+ SELECT ARRAY_AGG(STRUCT(wdl.wdl_task_name, cost, usage_start_time, usage_end_time, wdlskuacc.skus)) AS wdl_tasks
+ from wdl
+ INNER JOIN wdlskuacc on wdlskuacc.wdl_task_name = wdl.wdl_task_name
+ )
+ -- cromwell_workflow record
+ , cwisku AS (
+ SELECT cromwell_workflow_id, sku, SUM(cost) as cost
+ FROM d
+ WHERE cost != 0
+ AND cromwell_workflow_id IS NOT NULL
+ GROUP BY cromwell_workflow_id, sku
+ ORDER BY cromwell_workflow_id, cost desc
+ )
+ , cwiskuacc AS (
+ SELECT cromwell_workflow_id, ARRAY_AGG(STRUCT(sku, cost)) as skus
+ FROM cwisku
+ GROUP BY cromwell_workflow_id
+ )
+ , cwi AS (
+ SELECT d.cromwell_workflow_id,
+ sum(d.cost) AS cost,
+ MIN(d.usage_start_time) AS usage_start_time,
+ max(d.usage_end_time) AS usage_end_time,
+ COUNT(DISTINCT d.job_id) as jobs_cnt
+ FROM d
+ WHERE d.cromwell_workflow_id IS NOT NULL
+ GROUP BY cromwell_workflow_id
+ ORDER BY cromwell_workflow_id
+ )
+ ,cwiacc as (
+ SELECT ARRAY_AGG(STRUCT(cwi.cromwell_workflow_id, cost, usage_start_time, usage_end_time, cwiskuacc.skus)) AS cromwell_workflows
+ from cwi
+ INNER JOIN cwiskuacc on cwiskuacc.cromwell_workflow_id = cwi.cromwell_workflow_id
+ )
+ -- cromwell_sub_workflow record
+ , cswsku AS (
+ SELECT cromwell_sub_workflow_name, sku, SUM(cost) as cost
+ FROM d
+ WHERE cost != 0
+ AND cromwell_sub_workflow_name IS NOT NULL
+ GROUP BY cromwell_sub_workflow_name, sku
+ ORDER BY cromwell_sub_workflow_name, cost desc
+ )
+ , cswskuacc AS (
+ SELECT cromwell_sub_workflow_name, ARRAY_AGG(STRUCT(sku, cost)) as skus
+ FROM cswsku
+ GROUP BY cromwell_sub_workflow_name
+ )
+ , csw AS (
+ SELECT d.cromwell_sub_workflow_name,
+ sum(d.cost) AS cost,
+ MIN(d.usage_start_time) AS usage_start_time,
+ max(d.usage_end_time) AS usage_end_time,
+ COUNT(DISTINCT d.job_id) as jobs_cnt
+ FROM d
+ WHERE d.cromwell_sub_workflow_name IS NOT NULL
+ GROUP BY cromwell_sub_workflow_name
+ ORDER BY cromwell_sub_workflow_name
+ )
+ ,cswacc as (
+ SELECT ARRAY_AGG(STRUCT(csw.cromwell_sub_workflow_name, cost, usage_start_time, usage_end_time, cswskuacc.skus)) AS cromwell_sub_workflows
+ from csw
+ INNER JOIN cswskuacc on cswskuacc.cromwell_sub_workflow_name = csw.cromwell_sub_workflow_name
+ )
+ -- dataproc specific record
+ , dprocsku AS (
+ SELECT category as dataproc, sku, SUM(cost) as cost
+ FROM d
+ WHERE category = 'dataproc'
+ AND cost != 0
+ GROUP BY dataproc, sku
+ ORDER BY cost DESC
+ )
+ , dprocskuacc AS (
+ SELECT dataproc, ARRAY_AGG(STRUCT(sku, cost)) AS skus,
+ FROM dprocsku
+ GROUP BY dataproc
+ )
+ ,dproc AS (
+ SELECT category as dataproc, sum(cost) AS cost,
+ MIN(usage_start_time) AS usage_start_time,
+ max(usage_end_time) AS usage_end_time
+ FROM d
+ WHERE category = 'dataproc'
+ GROUP BY dataproc
+ )
+ , dprocacc as (
+ SELECT ARRAY_AGG(STRUCT(dproc.dataproc, cost, usage_start_time, usage_end_time, dprocskuacc.skus)) AS dataproc
+ from dproc
+ INNER JOIN dprocskuacc on dprocskuacc.dataproc = dproc.dataproc
+ )
+ -- merge all in one record
+ SELECT t.total, t3acc.topics, ccacc.categories, bacc.batches, arskuacc.skus,
+ wdlacc.wdl_tasks, cswacc.cromwell_sub_workflows, cwiacc.cromwell_workflows, seqgrpacc.seq_groups, dprocacc.dataproc
+ FROM t, t3acc, ccacc, bacc, arskuacc, wdlacc, cswacc, cwiacc, seqgrpacc, dprocacc
+
+ """
+
+ query_job_result = self._execute_query(_query, query_parameters, False)
+
+ if query_job_result:
+ return [
+ BillingBatchCostRecord.from_json(dict(row)) for row in query_job_result
+ ]
+
+ # return empty list if no record found
+ return []
diff --git a/db/python/tables/family_participant.py b/db/python/tables/family_participant.py
index 8aab115a8..e782038ab 100644
--- a/db/python/tables/family_participant.py
+++ b/db/python/tables/family_participant.py
@@ -20,7 +20,7 @@ async def create_row(
paternal_id: int,
maternal_id: int,
affected: int,
- notes: str = None,
+ notes: str | None = None,
) -> Tuple[int, int]:
"""
Create a new sample, and add it to database
@@ -111,8 +111,8 @@ async def get_rows(
keys = [
'fp.family_id',
'p.id as individual_id',
- 'fp.paternal_participant_id',
- 'fp.maternal_participant_id',
+ 'fp.paternal_participant_id as paternal_id',
+ 'fp.maternal_participant_id as maternal_id',
'p.reported_sex as sex',
'fp.affected',
]
@@ -153,7 +153,7 @@ async def get_rows(
'sex',
'affected',
]
- ds = [dict(zip(ordered_keys, row)) for row in rows]
+ ds = [{k: row[k] for k in ordered_keys} for row in rows]
return ds
@@ -161,7 +161,7 @@ async def get_row(
self,
family_id: int,
participant_id: int,
- ):
+ ) -> dict | None:
"""Get a single row from the family_participant table"""
values: Dict[str, Any] = {
'family_id': family_id,
@@ -169,7 +169,13 @@ async def get_row(
}
_query = """
-SELECT fp.family_id, p.id as individual_id, fp.paternal_participant_id, fp.maternal_participant_id, p.reported_sex as sex, fp.affected
+SELECT
+ fp.family_id as family_id,
+ p.id as individual_id,
+ fp.paternal_participant_id as paternal_id,
+ fp.maternal_participant_id as maternal_id,
+ p.reported_sex as sex,
+ fp.affected
FROM family_participant fp
INNER JOIN family f ON f.id = fp.family_id
INNER JOIN participant p on fp.participant_id = p.id
@@ -177,6 +183,8 @@ async def get_row(
"""
row = await self.connection.fetch_one(_query, values)
+ if not row:
+ return None
ordered_keys = [
'family_id',
@@ -186,7 +194,7 @@ async def get_row(
'sex',
'affected',
]
- ds = dict(zip(ordered_keys, row))
+ ds = {k: row[k] for k in ordered_keys}
return ds
diff --git a/db/python/tables/participant_phenotype.py b/db/python/tables/participant_phenotype.py
index ea010588d..d50b5bbb7 100644
--- a/db/python/tables/participant_phenotype.py
+++ b/db/python/tables/participant_phenotype.py
@@ -44,7 +44,7 @@ async def add_key_value_rows(self, rows: List[Tuple[int, str, Any]]) -> None:
)
async def get_key_value_rows_for_participant_ids(
- self, participant_ids=List[int]
+ self, participant_ids: List[int]
) -> Dict[int, Dict[str, Any]]:
"""
Get (participant_id, description, value),
@@ -64,7 +64,9 @@ async def get_key_value_rows_for_participant_ids(
)
formed_key_value_pairs: Dict[int, Dict[str, Any]] = defaultdict(dict)
for row in rows:
- pid, key, value = row
+ pid = row['participant_id']
+ key = row['description']
+ value = row['value']
formed_key_value_pairs[pid][key] = json.loads(value)
return formed_key_value_pairs
@@ -86,7 +88,9 @@ async def get_key_value_rows_for_all_participants(
rows = await self.connection.fetch_all(_query, {'project': project})
formed_key_value_pairs: Dict[int, Dict[str, Any]] = defaultdict(dict)
for row in rows:
- pid, key, value = row
+ pid = row['participant_id']
+ key = row['description']
+ value = row['value']
formed_key_value_pairs[pid][key] = json.loads(value)
return formed_key_value_pairs
diff --git a/etl/load/main.py b/etl/load/main.py
index 65e7a2dc6..28f3b291d 100644
--- a/etl/load/main.py
+++ b/etl/load/main.py
@@ -1,6 +1,7 @@
import asyncio
import base64
import datetime
+import importlib.metadata
import json
import logging
import os
@@ -10,7 +11,6 @@
import flask
import functions_framework
import google.cloud.bigquery as bq
-import pkg_resources
from google.cloud import pubsub_v1, secretmanager
from metamist.parser.generic_parser import GenericParser # type: ignore
@@ -25,14 +25,23 @@
@lru_cache
def _get_bq_client():
+ assert BIGQUERY_TABLE, 'BIGQUERY_TABLE is not set'
+ assert BIGQUERY_LOG_TABLE, 'BIGQUERY_LOG_TABLE is not set'
return bq.Client()
@lru_cache
def _get_secret_manager():
+ assert ETL_ACCESSOR_CONFIG_SECRET, 'ETL_ACCESSOR_CONFIG_SECRET is not set'
return secretmanager.SecretManagerServiceClient()
+@lru_cache
+def _get_pubsub_client():
+ assert NOTIFICATION_PUBSUB_TOPIC, 'NOTIFICATION_PUBSUB_TOPIC is not set'
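+ # cached by lru_cache so a single PublisherClient is created and reused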
+ return pubsub_v1.PublisherClient()
+
+
class ParsingStatus:
"""
Enum type to distinguish between sucess and failure of parsing
@@ -65,7 +74,7 @@ def call_parser(parser_obj, row_json) -> tuple[str, str]:
async def run_parser_capture_result(parser_obj, row_data, res, status):
try:
# TODO better error handling
- r = await parser_obj.from_json([row_data], confirm=False, dry_run=True)
+ r = await parser_obj.from_json(row_data, confirm=False)
res.append(r)
status.append(ParsingStatus.SUCCESS)
except Exception as e: # pylint: disable=broad-exception-caught
@@ -148,7 +157,7 @@ def process_rows(
# publish to notification pubsub
msg_title = 'Metamist ETL Load Failed'
try:
- pubsub_client = pubsub_v1.PublisherClient()
+ pubsub_client = _get_pubsub_client()
pubsub_client.publish(
NOTIFICATION_PUBSUB_TOPIC,
json.dumps({'title': msg_title} | log_record).encode(),
@@ -212,6 +221,15 @@ def etl_load(request: flask.Request):
'message': f'Missing or empty request_id: {jbody_str}',
}, 400
+ return process_request(request_id, delivery_attempt)
+
+
+def process_request(
+ request_id: str, delivery_attempt: int | None = None
+) -> tuple[dict, int]:
+ """
+ Process request_id, delivery_attempt and return result
+ """
# locate the request_id in bq
query = f"""
SELECT * FROM `{BIGQUERY_TABLE}` WHERE request_id = @request_id
@@ -322,13 +340,16 @@ def get_parser_instance(
accessor_config: dict[
str,
- list[
- dict[
- Literal['name']
- | Literal['parser_name']
- | Literal['default_parameters'],
- Any,
- ]
+ dict[
+ Literal['parsers'],
+ list[
+ dict[
+ Literal['name']
+ | Literal['parser_name']
+ | Literal['default_parameters'],
+ Any,
+ ]
+ ],
],
] = get_accessor_config()
@@ -341,7 +362,7 @@ def get_parser_instance(
etl_accessor_config = next(
(
accessor_config
- for accessor_config in accessor_config[submitting_user]
+ for accessor_config in accessor_config[submitting_user].get('parsers', [])
if accessor_config['name'].strip(STRIP_CHARS)
== request_type.strip(STRIP_CHARS)
),
@@ -387,7 +408,8 @@ def prepare_parser_map() -> dict[str, type[GenericParser]]:
loop through metamist_parser entry points and create map of parsers
"""
parser_map = {}
- for entry_point in pkg_resources.iter_entry_points('metamist_parser'):
+
+ for entry_point in importlib.metadata.entry_points().get('metamist_parser'):
parser_cls = entry_point.load()
parser_short_name, parser_version = parser_cls.get_info()
parser_map[f'{parser_short_name}/{parser_version}'] = parser_cls
diff --git a/etl/test/test_etl_load.py b/etl/test/test_etl_load.py
index 203bc9f33..a1ccbb4c1 100644
--- a/etl/test/test_etl_load.py
+++ b/etl/test/test_etl_load.py
@@ -200,13 +200,15 @@ def test_get_parser_instance_success(
"""Test get_parser_instance success"""
mock_get_accessor_config.return_value = {
- 'user@test.com': [
- {
- 'name': 'test/v1',
- 'default_parameters': {},
- # 'parser_name': 'test',
- }
- ]
+ 'user@test.com': {
+ 'parsers': [
+ {
+ 'name': 'test/v1',
+ 'default_parameters': {},
+ # 'parser_name': 'test',
+ }
+ ]
+ }
}
mock_prepare_parser_map.return_value = {
@@ -234,13 +236,15 @@ def test_get_parser_instance_different_parser_name(
"""Test get_parser_instance success"""
mock_get_accessor_config.return_value = {
- 'user@test.com': [
- {
- 'name': 'test/v1',
- 'default_parameters': {'project': 'test'},
- 'parser_name': 'different_parser/name',
- }
- ]
+ 'user@test.com': {
+ 'parsers': [
+ {
+ 'name': 'test/v1',
+ 'default_parameters': {'project': 'test'},
+ 'parser_name': 'different_parser/name',
+ }
+ ]
+ }
}
mock_prepare_parser_map.return_value = {
@@ -284,12 +288,14 @@ def test_get_parser_no_matching_config(
"""Test get_parser_instance success"""
mock_get_accessor_config.return_value = {
- 'user@test.com': [
- {
- 'name': 'test/v1',
- 'default_parameters': {'project': 'test'},
- }
- ]
+ 'user@test.com': {
+ 'parsers': [
+ {
+ 'name': 'test/v1',
+ 'default_parameters': {'project': 'test'},
+ }
+ ]
+ }
}
# this doesn't need to be mocked as it fails before here
@@ -316,12 +322,14 @@ def test_get_parser_no_matching_parser(
"""Test get_parser_instance success"""
mock_get_accessor_config.return_value = {
- 'user@test.com': [
- {
- 'name': 'a/b',
- 'default_parameters': {'project': 'test'},
- }
- ]
+ 'user@test.com': {
+ 'parsers': [
+ {
+ 'name': 'a/b',
+ 'default_parameters': {'project': 'test'},
+ }
+ ]
+ }
}
mock_prepare_parser_map.return_value = {
diff --git a/metamist/parser/generic_parser.py b/metamist/parser/generic_parser.py
index 8dc2de5d7..fbe95136b 100644
--- a/metamist/parser/generic_parser.py
+++ b/metamist/parser/generic_parser.py
@@ -534,6 +534,9 @@ async def from_json(self, rows, confirm=False, dry_run=False):
If no participants are present, groups samples by their IDs.
For each sample, gets its sequencing groups by their keys. For each sequencing group, groups assays and analyses.
"""
+ if not isinstance(rows, list):
+ rows = [rows]
+
await self.validate_rows(rows)
# one participant with no value
diff --git a/models/models/__init__.py b/models/models/__init__.py
index ce5868cb1..8b100662f 100644
--- a/models/models/__init__.py
+++ b/models/models/__init__.py
@@ -12,10 +12,10 @@
from models.models.assay import Assay, AssayInternal, AssayUpsert, AssayUpsertInternal
from models.models.audit_log import AuditLogId, AuditLogInternal
from models.models.billing import (
+ BillingBatchCostRecord,
BillingColumn,
BillingCostBudgetRecord,
BillingCostDetailsRecord,
- BillingHailBatchCostRecord,
BillingInternal,
BillingTotalCostQueryModel,
BillingTotalCostRecord,
diff --git a/models/models/billing.py b/models/models/billing.py
index 6b5a98fbf..83cb7586f 100644
--- a/models/models/billing.py
+++ b/models/models/billing.py
@@ -135,25 +135,13 @@ def str_to_enum(cls, value: str) -> 'BillingColumn':
def raw_cols(cls) -> list[str]:
"""Return list of raw column names"""
return [
- BillingColumn.ID.value,
BillingColumn.TOPIC.value,
- BillingColumn.SERVICE.value,
BillingColumn.SKU.value,
BillingColumn.USAGE_START_TIME.value,
BillingColumn.USAGE_END_TIME.value,
- BillingColumn.PROJECT.value,
BillingColumn.LABELS.value,
- BillingColumn.SYSTEM_LABELS.value,
- BillingColumn.LOCATION.value,
- BillingColumn.EXPORT_TIME.value,
BillingColumn.COST.value,
BillingColumn.CURRENCY.value,
- BillingColumn.CURRENCY_CONVERSION_RATE.value,
- BillingColumn.USAGE.value,
- BillingColumn.CREDITS.value,
- BillingColumn.INVOICE.value,
- BillingColumn.COST_TYPE.value,
- BillingColumn.ADJUSTMENT_INFO.value,
]
@classmethod
@@ -362,9 +350,33 @@ def from_json(record):
)
-class BillingHailBatchCostRecord(SMBase):
+class BillingBatchCostRecord(SMBase):
"""Return class for the Billing Cost by batch_id/ar_guid"""
- ar_guid: str | None
- batch_ids: list[str] | None
- costs: list[dict] | None
+ total: dict | None
+ topics: list[dict] | None
+ categories: list[dict] | None
+ batches: list[dict] | None
+ skus: list[dict] | None
+ seq_groups: list[dict] | None
+
+ wdl_tasks: list[dict] | None
+ cromwell_sub_workflows: list[dict] | None
+ cromwell_workflows: list[dict] | None
+ dataproc: list[dict] | None
+
+ @staticmethod
+ def from_json(record):
+ """Create BillingBatchCostRecord from json"""
+ return BillingBatchCostRecord(
+ total=record.get('total'),
+ topics=record.get('topics'),
+ categories=record.get('categories'),
+ batches=record.get('batches'),
+ skus=record.get('skus'),
+ seq_groups=record.get('seq_groups'),
+ wdl_tasks=record.get('wdl_tasks'),
+ cromwell_sub_workflows=record.get('cromwell_sub_workflows'),
+ cromwell_workflows=record.get('cromwell_workflows'),
+ dataproc=record.get('dataproc'),
+ )
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 3bb86aa87..c73f19f70 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -7,7 +7,7 @@ flake8-bugbear
nest-asyncio
pre-commit
pylint
-testcontainers[mariadb]==3.7.1
+testcontainers[mariadb]>=4.0.0
types-PyMySQL
# some strawberry dependency
strawberry-graphql[debug-server]==0.206.0
diff --git a/requirements.txt b/requirements.txt
index 1f9640ea2..172fa8fe0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -17,8 +17,8 @@ uvicorn==0.18.3
fastapi[all]==0.85.1
strawberry-graphql[fastapi]==0.206.0
python-multipart==0.0.5
-databases[mysql]==0.6.1
-SQLAlchemy==1.4.41
+databases[mysql]==0.9.0
+SQLAlchemy==2.0.28
cryptography>=41.0.0
python-dateutil==2.8.2
slack-sdk==3.20.2
diff --git a/test/test_api_billing.py b/test/test_api_billing.py
index 4811187a7..2f5c00665 100644
--- a/test/test_api_billing.py
+++ b/test/test_api_billing.py
@@ -1,13 +1,14 @@
# pylint: disable=protected-access too-many-public-methods
+import json
from test.testbase import run_as_sync
from test.testbqbase import BqTest
from unittest.mock import patch
from api.routes import billing
from models.models import (
+ BillingBatchCostRecord,
BillingColumn,
BillingCostBudgetRecord,
- BillingHailBatchCostRecord,
BillingTotalCostQueryModel,
BillingTotalCostRecord,
)
@@ -54,15 +55,28 @@ async def test_get_cost_by_ar_guid(
Test get_cost_by_ar_guid function
"""
ar_guid = 'test_ar_guid'
- mockup_record = BillingHailBatchCostRecord(
- ar_guid=ar_guid, batch_ids=None, costs=None
- )
+ mockup_record_json = {
+ 'total': {'ar_guid': ar_guid},
+ 'topics': None,
+ 'categories': None,
+ 'batches': [],
+ 'skus': None,
+ 'seq_groups': None,
+ 'wdl_tasks': None,
+ 'cromwell_sub_workflows': None,
+ 'cromwell_workflows': None,
+ 'dataproc': None,
+ }
+
+ mockup_record = [BillingBatchCostRecord.from_json(mockup_record_json)]
mock_get_billing_layer.return_value = self.layer
mock_get_cost_by_ar_guid.return_value = mockup_record
- records = await billing.get_cost_by_ar_guid(
+ response = await billing.get_cost_by_ar_guid(
ar_guid, author=TEST_API_BILLING_USER
)
- self.assertEqual(mockup_record, records)
+ self.assertEqual(
+ [mockup_record_json], json.loads(response.body.decode('utf-8'))
+ )
@run_as_sync
@patch('api.routes.billing._get_billing_layer_from')
@@ -75,15 +89,28 @@ async def test_get_cost_by_batch_id(
"""
ar_guid = 'test_ar_guid'
batch_id = 'test_batch_id'
- mockup_record = BillingHailBatchCostRecord(
- ar_guid=ar_guid, batch_ids=[batch_id], costs=None
- )
+ mockup_record_json = {
+ 'total': {'ar_guid': ar_guid},
+ 'topics': None,
+ 'categories': None,
+ 'batches': [{'batch_id': batch_id}],
+ 'skus': None,
+ 'seq_groups': None,
+ 'wdl_tasks': None,
+ 'cromwell_sub_workflows': None,
+ 'cromwell_workflows': None,
+ 'dataproc': None,
+ }
+
+ mockup_record = [BillingBatchCostRecord.from_json(mockup_record_json)]
mock_get_billing_layer.return_value = self.layer
mock_get_cost_by_batch_id.return_value = mockup_record
- records = await billing.get_cost_by_batch_id(
+ response = await billing.get_cost_by_batch_id(
batch_id, author=TEST_API_BILLING_USER
)
- self.assertEqual(mockup_record, records)
+ self.assertEqual(
+ [mockup_record_json], json.loads(response.body.decode('utf-8'))
+ )
@run_as_sync
@patch('api.routes.billing._get_billing_layer_from')
diff --git a/test/test_bq_billing_base.py b/test/test_bq_billing_base.py
index 655b40486..609983790 100644
--- a/test/test_bq_billing_base.py
+++ b/test/test_bq_billing_base.py
@@ -362,9 +362,10 @@ def test_execute_query_results_not_as_list(self):
given_bq_results = [[], [123], ['a', 'b', 'c']]
for bq_result in given_bq_results:
# mock BigQuery result
- self.bq_client.query.return_value = bq_result
+ self.bq_result.result.return_value = bq_result
+ self.bq_result.total_bytes_processed = 0
results = self.table_obj._execute_query(
- sql_query, sql_params, results_as_list=False
+ sql_query, sql_params, results_as_list=True
)
self.assertEqual(bq_result, results)
@@ -380,9 +381,10 @@ def test_execute_query_with_sql_params(self):
given_bq_results = [[], [123], ['a', 'b', 'c']]
for bq_result in given_bq_results:
# mock BigQuery result
- self.bq_client.query.return_value = bq_result
+ self.bq_result.result.return_value = bq_result
+ self.bq_result.total_bytes_processed = 0
results = self.table_obj._execute_query(
- sql_query, sql_params, results_as_list=False
+ sql_query, sql_params, results_as_list=True
)
self.assertEqual(bq_result, results)
diff --git a/test/test_layers_billing.py b/test/test_layers_billing.py
index be13f3fa3..535d374ad 100644
--- a/test/test_layers_billing.py
+++ b/test/test_layers_billing.py
@@ -8,11 +8,7 @@
from db.python.layers.billing import BillingLayer
from models.enums import BillingSource
-from models.models import (
- BillingColumn,
- BillingHailBatchCostRecord,
- BillingTotalCostQueryModel,
-)
+from models.models import BillingColumn, BillingTotalCostQueryModel
class TestBillingLayer(BqTest):
@@ -359,7 +355,11 @@ async def test_get_running_cost(self):
@run_as_sync
async def test_get_cost_by_ar_guid(self):
- """Test get_cost_by_ar_guid"""
+ """
+ Test get_cost_by_ar_guid
+ This test only covers the paths in the layer;
+ the logic and processing are tested in test/test_bq_billing_base.py
+ """
layer = BillingLayer(self.connection)
@@ -367,9 +367,7 @@ async def test_get_cost_by_ar_guid(self):
records = await layer.get_cost_by_ar_guid(ar_guid=None)
# return empty record
- self.assertEqual(
- BillingHailBatchCostRecord(ar_guid=None, batch_ids=[], costs=[]), records
- )
+ self.assertEqual([], records)
# dummy ar_guid, no mockup data, return empty results
dummy_ar_guid = '12345678'
@@ -377,13 +375,11 @@ async def test_get_cost_by_ar_guid(self):
# return empty record
self.assertEqual(
- BillingHailBatchCostRecord(ar_guid=dummy_ar_guid, batch_ids=[], costs=[]),
+ [],
records,
)
- # dummy ar_guid, mockup batch_id
-
- # mock BigQuery result
+ # mock BigQuery first query result
given_start_day = datetime.datetime(2023, 1, 1, 0, 0)
given_end_day = datetime.datetime(2023, 1, 1, 2, 3)
dummy_batch_id = '12345'
@@ -401,19 +397,21 @@ async def test_get_cost_by_ar_guid(self):
self.bq_result.result.return_value = mock_rows
records = await layer.get_cost_by_ar_guid(ar_guid=dummy_ar_guid)
- # returns ar_guid, batch_id and empty cost as those were not mocked up
+ # returns empty list as those were not mocked up
# we do not need to test cost calculation here,
# as those are tested in test/test_bq_billing_base.py
self.assertEqual(
- BillingHailBatchCostRecord(
- ar_guid=dummy_ar_guid, batch_ids=[dummy_batch_id], costs=[{}]
- ),
+ [],
records,
)
@run_as_sync
async def test_get_cost_by_batch_id(self):
- """Test get_cost_by_batch_id"""
+ """
+ Test get_cost_by_batch_id
+ This test only covers the paths in the layer;
+ the logic and processing are tested in test/test_bq_billing_base.py
+ """
layer = BillingLayer(self.connection)
@@ -421,9 +419,7 @@ async def test_get_cost_by_batch_id(self):
records = await layer.get_cost_by_batch_id(batch_id=None)
# return empty record
- self.assertEqual(
- BillingHailBatchCostRecord(ar_guid=None, batch_ids=[], costs=[]), records
- )
+ self.assertEqual([], records)
# dummy ar_guid, no mockup data, return empty results
dummy_batch_id = '12345'
@@ -431,7 +427,7 @@ async def test_get_cost_by_batch_id(self):
# return empty record
self.assertEqual(
- BillingHailBatchCostRecord(ar_guid=None, batch_ids=[], costs=[]),
+ [],
records,
)
@@ -459,12 +455,10 @@ async def test_get_cost_by_batch_id(self):
self.bq_result.result.return_value = mock_rows
records = await layer.get_cost_by_batch_id(batch_id=dummy_batch_id)
- # returns ar_guid, batch_id and empty cost as those were not mocked up
+ # returns empty list as those were not mocked up
# we do not need to test cost calculation here,
# as those are tested in test/test_bq_billing_base.py
self.assertEqual(
- BillingHailBatchCostRecord(
- ar_guid=dummy_ar_guid, batch_ids=[dummy_batch_id], costs=[{}]
- ),
+ [],
records,
)
diff --git a/test/test_web.py b/test/test_web.py
index cc5a0bb63..cc7ffad99 100644
--- a/test/test_web.py
+++ b/test/test_web.py
@@ -221,7 +221,7 @@ async def test_project_summary_empty(self):
seqr_sync_types=[],
)
- self.assertEqual(expected, result)
+ self.assertDataclassEqual(expected, result)
@run_as_sync
async def test_project_summary_single_entry(self):
@@ -232,7 +232,7 @@ async def test_project_summary_single_entry(self):
result = await self.webl.get_project_summary(token=0, grid_filter=[])
result.participants = []
- self.assertEqual(SINGLE_PARTICIPANT_RESULT, result)
+ self.assertDataclassEqual(SINGLE_PARTICIPANT_RESULT, result)
@run_as_sync
async def test_project_summary_to_external(self):
@@ -289,7 +289,7 @@ async def project_summary_with_filter_with_results(self):
],
)
filtered_result_success.participants = []
- self.assertEqual(SINGLE_PARTICIPANT_RESULT, filtered_result_success)
+ self.assertDataclassEqual(SINGLE_PARTICIPANT_RESULT, filtered_result_success)
@run_as_sync
async def project_summary_with_filter_no_results(self):
@@ -323,7 +323,7 @@ async def project_summary_with_filter_no_results(self):
seqr_sync_types=[],
)
- self.assertEqual(empty_result, filtered_result_empty)
+ self.assertDataclassEqual(empty_result, filtered_result_empty)
@run_as_sync
async def test_project_summary_multiple_participants(self):
@@ -376,7 +376,7 @@ async def test_project_summary_multiple_participants(self):
two_samples_result.participants = []
- self.assertEqual(expected_data_two_samples, two_samples_result)
+ self.assertDataclassEqual(expected_data_two_samples, two_samples_result)
@run_as_sync
async def test_project_summary_multiple_participants_and_filter(self):
@@ -436,7 +436,7 @@ async def test_project_summary_multiple_participants_and_filter(self):
)
two_samples_result_filtered.participants = []
- self.assertEqual(
+ self.assertDataclassEqual(
expected_data_two_samples_filtered, two_samples_result_filtered
)
@@ -499,7 +499,9 @@ async def test_field_with_space(self):
seqr_sync_types=[],
)
- self.assertEqual(expected_data_two_samples_filtered, test_field_with_space)
+ self.assertDataclassEqual(
+ expected_data_two_samples_filtered, test_field_with_space
+ )
@run_as_sync
async def test_project_summary_inactive_sequencing_group(self):
diff --git a/test/testbase.py b/test/testbase.py
index 04c8e1d2d..fa60ee66f 100644
--- a/test/testbase.py
+++ b/test/testbase.py
@@ -1,6 +1,7 @@
# pylint: disable=invalid-overridden-method
import asyncio
+import dataclasses
import logging
import os
import socket
@@ -96,10 +97,10 @@ async def setup():
logger = logging.getLogger()
try:
set_all_access(True)
- db = MySqlContainer('mariadb:10.8.3')
+ db = MySqlContainer('mariadb:11.2.2')
port_to_expose = find_free_port()
# override the default port to map the container to
- db.with_bind_ports(db.port_to_expose, port_to_expose)
+ db.with_bind_ports(db.port, port_to_expose)
logger.disabled = True
db.start()
logger.disabled = False
@@ -111,7 +112,7 @@ async def setup():
con_string = db.get_connection_url()
con_string = 'mysql://' + con_string.split('://', maxsplit=1)[1]
- lcon_string = f'jdbc:mariadb://{db.get_container_host_ip()}:{port_to_expose}/{db.MYSQL_DATABASE}'
+ lcon_string = f'jdbc:mariadb://{db.get_container_host_ip()}:{port_to_expose}/{db.dbname}'
# apply the liquibase schema
command = [
'liquibase',
@@ -120,8 +121,8 @@ async def setup():
*('--url', lcon_string),
*('--driver', 'org.mariadb.jdbc.Driver'),
*('--classpath', db_prefix + '/mariadb-java-client-3.0.3.jar'),
- *('--username', db.MYSQL_USER),
- *('--password', db.MYSQL_PASSWORD),
+ *('--username', db.username),
+ *('--password', db.password),
'update',
]
subprocess.check_output(command, stderr=subprocess.STDOUT)
@@ -175,7 +176,7 @@ async def setup():
def tearDownClass(cls) -> None:
db = cls.dbs.get(cls.__name__)
if db:
- db.exec(f'DROP DATABASE {db.MYSQL_DATABASE};')
+ db.exec(f'DROP DATABASE {db.dbname};')
db.stop()
def setUp(self) -> None:
@@ -224,6 +225,19 @@ async def audit_log_id(self):
"""Get audit_log_id for the test"""
return await self.connection.audit_log_id()
+ def assertDataclassEqual(self, a, b):
+ """Assert two dataclasses are equal"""
+
+ def to_dict(obj):
+ d = dataclasses.asdict(obj)
+ for k, v in d.items():
+ if dataclasses.is_dataclass(v):
+ d[k] = to_dict(v)
+ return d
+
+ self.maxDiff = None
+ self.assertDictEqual(to_dict(a), to_dict(b))
+
class DbIsolatedTest(DbTest):
"""
diff --git a/web/package-lock.json b/web/package-lock.json
index 62926084c..f04895d4d 100644
--- a/web/package-lock.json
+++ b/web/package-lock.json
@@ -1,12 +1,12 @@
{
"name": "metamist",
- "version": "6.4.0",
+ "version": "6.6.2",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "metamist",
- "version": "6.4.0",
+ "version": "6.6.2",
"dependencies": {
"@apollo/client": "^3.7.3",
"@artsy/fresnel": "^6.2.1",
@@ -35,6 +35,7 @@
"react-responsive": "^9.0.2",
"react-router-dom": "^6.0.1",
"react-syntax-highlighter": "^15.4.4",
+ "react-virtuoso": "^4.6.3",
"recharts": "^2.8.0",
"remark-gfm": "^3.0.1",
"remark-toc": "^8.0.1",
@@ -10171,6 +10172,18 @@
"react-dom": ">=16.6.0"
}
},
+ "node_modules/react-virtuoso": {
+ "version": "4.6.3",
+ "resolved": "https://registry.npmjs.org/react-virtuoso/-/react-virtuoso-4.6.3.tgz",
+ "integrity": "sha512-NcoSsf4B0OCx7U8i2s+VWe8b9e+FWzcN/5ly4hKjErynBzGONbWORZ1C5amUlWrPi6+HbUQ2PjnT4OpyQIpP9A==",
+ "engines": {
+ "node": ">=10"
+ },
+ "peerDependencies": {
+ "react": ">=16 || >=17 || >= 18",
+ "react-dom": ">=16 || >=17 || >= 18"
+ }
+ },
"node_modules/readable-stream": {
"version": "3.6.2",
"dev": true,
diff --git a/web/package.json b/web/package.json
index 852999cee..1c1527f30 100644
--- a/web/package.json
+++ b/web/package.json
@@ -30,6 +30,7 @@
"react-responsive": "^9.0.2",
"react-router-dom": "^6.0.1",
"react-syntax-highlighter": "^15.4.4",
+ "react-virtuoso": "^4.6.3",
"recharts": "^2.8.0",
"remark-gfm": "^3.0.1",
"remark-toc": "^8.0.1",
diff --git a/web/src/pages/billing/BillingCostByAnalysis.tsx b/web/src/pages/billing/BillingCostByAnalysis.tsx
index 87e1ef153..d18b5d58d 100644
--- a/web/src/pages/billing/BillingCostByAnalysis.tsx
+++ b/web/src/pages/billing/BillingCostByAnalysis.tsx
@@ -4,10 +4,12 @@ import { Button, Card, Grid, Input, Message, Select, Dropdown } from 'semantic-u
import SearchIcon from '@mui/icons-material/Search'
import LoadingDucks from '../../shared/components/LoadingDucks/LoadingDucks'
-import { BillingApi, BillingTotalCostRecord } from '../../sm-api'
-import HailBatchGrid from './components/HailBatchGrid'
+import { BillingApi, BillingTotalCostRecord, AnalysisApi } from '../../sm-api'
+import BatchGrid from './components/BatchGrid'
+
import { getMonthStartDate } from '../../shared/utilities/monthStartEndDate'
import generateUrl from '../../shared/utilities/generateUrl'
+import { List } from 'lodash'
enum SearchType {
Ar_guid,
@@ -27,6 +29,32 @@ const BillingCostByAnalysis: React.FunctionComponent = () => {
const [data, setData] = React.useState
Ar guid: f5a065d2-c51f-46b7-a920-a89b639fc4ba
- Batch id: 430604, 430605
+ Batch id: 430604
+
+ Hail Batch + DataProc: 433599
+
+ Cromwell: ec3f961f-7e16-4fb0-a3e3-9fc93006ab42
+
+ Large Hail Batch with 7.5K jobs: a449eea5-7150-441a-9ffe-bd71587c3fe2