diff --git a/api/graphql/schema.py b/api/graphql/schema.py index 436a1c334..ed263ad91 100644 --- a/api/graphql/schema.py +++ b/api/graphql/schema.py @@ -1,6 +1,6 @@ # type: ignore # flake8: noqa -# pylint: disable=no-value-for-parameter,redefined-builtin,missing-function-docstring,unused-argument +# pylint: disable=no-value-for-parameter,redefined-builtin,missing-function-docstring,unused-argument,too-many-lines """ Schema for GraphQL. @@ -8,6 +8,7 @@ and defaults to decide the GraphQL schema, so it might not necessarily look correct. """ import datetime +from collections import Counter from inspect import isclass import strawberry @@ -18,8 +19,10 @@ from api.graphql.filters import GraphQLFilter, GraphQLMetaFilter from api.graphql.loaders import LoaderKeys, get_context from db.python import enum_tables +from db.python.gcp_connect import BqConnection from db.python.layers import AnalysisLayer, SampleLayer, SequencingGroupLayer from db.python.layers.assay import AssayLayer +from db.python.layers.billing import BillingLayer from db.python.layers.family import FamilyLayer from db.python.tables.analysis import AnalysisFilter from db.python.tables.assay import AssayFilter @@ -32,6 +35,10 @@ AnalysisInternal, AssayInternal, AuditLogInternal, + BillingColumn, + BillingHailBatchCostRecord, + BillingInternal, + BillingTotalCostQueryModel, FamilyInternal, ParticipantInternal, Project, @@ -66,6 +73,23 @@ async def m(info: Info) -> list[str]: GraphQLEnum = strawberry.type(type('GraphQLEnum', (object,), enum_methods)) +def to_camel_case(test_str: str) -> str: + # using for loop to convert string to camel case + result = '' + capitalize_next = False + for char in test_str: + if char == '_': + capitalize_next = True + else: + if capitalize_next: + result += char.upper() + capitalize_next = False + else: + result += char + + return result + + @strawberry.type class GraphQLProject: """Project GraphQL model""" @@ -592,6 +616,133 @@ async def sample(self, info: Info, root: 'GraphQLAssay') -> GraphQLSample: return GraphQLSample.from_internal(sample) +@strawberry.type +class GraphQLBilling: + """GraphQL Billing""" + + id: str | None + ar_guid: str | None + gcp_project: str | None + topic: str | None + batch_id: str | None + cost_category: str | None + day: datetime.date | None + cost: float | None + + @staticmethod + def from_internal(internal: BillingInternal) -> 'GraphQLBilling': + return GraphQLBilling( + id=internal.id, + ar_guid=internal.ar_guid, + gcp_project=internal.gcp_project, + topic=internal.topic, + batch_id=internal.batch_id, + cost_category=internal.cost_category, + day=internal.day, + cost=internal.cost, + ) + + +@strawberry.type +class GraphQLBatchCostRecord: + """GraphQL Billing""" + + id: str | None + ar_guid: str | None + batch_id: str | None + job_id: str | None + day: datetime.date | None + + topic: str | None + namespace: str | None + name: str | None + + sku: str | None + cost: float | None + url: str | None + + @staticmethod + def from_json(json: dict) -> 'GraphQLBatchCostRecord': + return GraphQLBatchCostRecord( + id=json.get('id'), + ar_guid=json.get('ar_guid'), + batch_id=json.get('batch_id'), + job_id=json.get('job_id'), + day=json.get('day'), + topic=json.get('topic'), + namespace=json.get('namespace'), + name=json.get('name'), + sku=json.get('sku'), + cost=json.get('cost'), + url=json.get('url'), + ) + + @staticmethod + def from_internal( + internal: BillingHailBatchCostRecord, fields: list[str] | None = None + ) -> list['GraphQLBatchCostRecord']: + """ + TODO sum the costs based on selected fields + """ + results = [] + if not internal: + return results + + ar_guid = internal.ar_guid + + if fields is None: + for rec in internal.costs: + results.append( + GraphQLBatchCostRecord( + id=rec.get('id'), + ar_guid=ar_guid, + batch_id=rec.get('batch_id'), + job_id=rec.get('job_id'), + day=rec.get('day'), + topic=rec.get('topic'), + namespace=rec.get('namespace'), + name=rec.get('batch_name'), + sku=rec.get('batch_resource'), + cost=rec.get('cost'), + url=rec.get('url'), + ) + ) + else: + # we need to aggregate sum(cost) by fields + # if cost not present, then do distinct like operation? + + # prepare the fields + aggregated = Counter() + + class_fields = list( + GraphQLBatchCostRecord.__dict__['__annotations__'].keys() + ) + for rec in internal.costs: + # create key based on fields + key = '_'.join( + [ + str(rec.get(f)) + if 'cost' not in f and to_camel_case(f) in fields + else '' + for f in class_fields + ] + ) + aggregated[key] += rec.get('cost', 0) + + for key, cost in aggregated.items(): + # split the key back to fields + fields = key.split('_') + # map to class_fields + record = {} + for pos in range(len(class_fields)): + record[class_fields[pos]] = fields[pos] + + record['cost'] = cost + results.append(GraphQLBatchCostRecord.from_json(record)) + + return results + + @strawberry.type class Query: """GraphQL Queries""" @@ -730,6 +881,145 @@ async def my_projects(self, info: Info) -> list[GraphQLProject]: ) return [GraphQLProject.from_internal(p) for p in projects] + """ + TODO split inot 4 or 5 different functions + e.g. billing_by_batch_id, billing_by_ar_guid, billing_by_topic, billing_by_gcp_project + """ + + @staticmethod + def get_billing_layer(info: Info) -> BillingLayer: + # TODO is there a better way to get the BQ connection? + connection = info.context['connection'] + bg_connection = BqConnection(connection.author) + return BillingLayer(bg_connection) + + @staticmethod + async def extract_fields(info: Info) -> list[str]: + from graphql.parser import GraphQLParser + + parser = GraphQLParser() + body = await info.context.get('request').json() + ast = parser.parse(body['query']) + fields = [f.name for f in ast.definitions[0].selections[-1].selections] + print('fields', fields) + return fields + + @strawberry.field + async def billing_by_batch_id( + self, + info: Info, + batch_id: str, + ) -> list[GraphQLBatchCostRecord]: + slayer = Query.get_billing_layer(info) + result = await slayer.get_cost_by_batch_id(batch_id) + fields = await Query.extract_fields(info) + return GraphQLBatchCostRecord.from_internal(result, fields) + + @strawberry.field + async def billing_by_ar_guid( + self, + info: Info, + ar_guid: str, + ) -> list[GraphQLBatchCostRecord]: + slayer = Query.get_billing_layer(info) + result = await slayer.get_cost_by_ar_guid(ar_guid) + fields = await Query.extract_fields(info) + return GraphQLBatchCostRecord.from_internal(result, fields) + + @strawberry.field + async def billing_by_topic( + self, + info: Info, + topic: str | None = None, + day: GraphQLFilter[datetime.datetime] | None = None, + cost: GraphQLFilter[float] | None = None, + ) -> list[GraphQLBilling]: + # slayer = Query.get_billing_layer(info) + return [] + + @strawberry.field + async def billing_by_gcp_project( + self, + info: Info, + gcp_project: str | None = None, + day: GraphQLFilter[datetime.datetime] | None = None, + cost: GraphQLFilter[float] | None = None, + ) -> list[GraphQLBilling]: + # slayer = Query.get_billing_layer(info) + return [] + + @strawberry.field + async def billing_todel( + self, + info: Info, + batch_id: str | None = None, + ar_guid: str | None = None, + topic: str | None = None, + gcp_project: str | None = None, + day: GraphQLFilter[datetime.datetime] | None = None, + cost: GraphQLFilter[float] | None = None, + ) -> list[GraphQLBilling]: + """ + This is the first raw implementation of Billing inside GraphQL + """ + # TODO check billing is enabled e.g.: + # if not is_billing_enabled(): + # raise ValueError('Billing is not enabled') + + slayer = get_billing_layer(info) + + if ar_guid: + res = await slayer.get_cost_by_ar_guid(ar_guid) + if res: + # only show the costs + res = res.costs + + elif batch_id: + res = await slayer.get_cost_by_batch_id(batch_id) + if res: + # only show the costs + res = res.costs + + else: + # TODO construct fields from request.body (selected attributes) + # For time being, just use these fields + fields = [ + BillingColumn.DAY, + BillingColumn.COST, + BillingColumn.COST_CATEGORY, + ] + + filters = {} + if topic: + filters['topic'] = topic + fields.append(BillingColumn.TOPIC) + if gcp_project: + filters['gcp_project'] = gcp_project + fields.append(BillingColumn.GCP_PROJECT) + + if day: + all_days_vals = day.all_values() + start_date = min(all_days_vals).strftime('%Y-%m-%d') + end_date = max(all_days_vals).strftime('%Y-%m-%d') + else: + # TODO we need to limit to small time periods to avoid huge charges + # If day is not selected use only current day records + start_date = datetime.datetime.now().strftime('%Y-%m-%d') + end_date = start_date + + query = BillingTotalCostQueryModel( + fields=fields, + start_date=start_date, + end_date=end_date, + filters=filters, + ) + res = await slayer.get_total_cost(query) + + return [ + GraphQLBilling.from_internal(BillingInternal.from_db(**dict(p))) + for p in res + ] + schema = strawberry.Schema( query=Query, mutation=None, extensions=[QueryDepthLimiter(max_depth=10)] diff --git a/test/test_api_billing.py b/test/test_api_billing.py new file mode 100644 index 000000000..f523c8f0f --- /dev/null +++ b/test/test_api_billing.py @@ -0,0 +1,21 @@ +import unittest +from test.testbase import run_as_sync + +from api.routes.billing import get_gcp_projects, is_billing_enabled + + +class TestApiBilling(unittest.TestCase): + """Test API Billing routes""" + + def test_is_billing_enabled(self): + """ """ + result = is_billing_enabled() + self.assertEqual(False, result) + + @run_as_sync + async def test_get_gcp_projects(self): + """ """ + with self.assertRaises(ValueError) as context: + _result = await get_gcp_projects('test_user') + + self.assertTrue('Billing is not enabled' in str(context.exception)) diff --git a/test/test_api_utils.py b/test/test_api_utils.py new file mode 100644 index 000000000..44d8b3ed5 --- /dev/null +++ b/test/test_api_utils.py @@ -0,0 +1,69 @@ +import unittest +from datetime import datetime + +from api.utils.dates import ( + get_invoice_month_range, + parse_date_only_string, + reformat_datetime, +) + + +class TestApiUtils(unittest.TestCase): + """Test API utils functions""" + + def test_parse_date_only_string(self): + """ """ + result_none = parse_date_only_string(None) + self.assertEqual(None, result_none) + + result_date = parse_date_only_string('2021-01-10') + self.assertEqual(2021, result_date.year) + self.assertEqual(1, result_date.month) + self.assertEqual(10, result_date.day) + + # test exception + invalid_date_str = '123456789' + with self.assertRaises(ValueError) as context: + parse_date_only_string(invalid_date_str) + + self.assertTrue( + f'Date could not be converted: {invalid_date_str}' in str(context.exception) + ) + + def test_get_invoice_month_range(self): + jan_2021 = datetime.strptime('2021-01-10', '%Y-%m-%d').date() + res_jan_2021 = get_invoice_month_range(jan_2021) + + # there is 3 (INVOICE_DAY_DIFF) days difference between invoice month st and end + self.assertEqual( + (datetime(2020, 12, 29).date(), datetime(2021, 2, 3).date()), + res_jan_2021, + ) + + dec_2021 = datetime.strptime('2021-12-10', '%Y-%m-%d').date() + res_dec_2021 = get_invoice_month_range(dec_2021) + + # there is 3 (INVOICE_DAY_DIFF) days difference between invoice month st and end + self.assertEqual( + (datetime(2021, 11, 28).date(), datetime(2022, 1, 3).date()), + res_dec_2021, + ) + + def test_reformat_datetime(self): + in_format = '%Y-%m-%d' + out_format = '%d/%m/%Y' + + result_none = reformat_datetime(None, in_format, out_format) + self.assertEqual(None, result_none) + + result_formatted = reformat_datetime('2021-11-09', in_format, out_format) + self.assertEqual('09/11/2021', result_formatted) + + # test exception + invalid_date_str = '123456789' + with self.assertRaises(ValueError) as context: + reformat_datetime(invalid_date_str, in_format, out_format) + + self.assertTrue( + f'Date could not be converted: {invalid_date_str}' in str(context.exception) + ) diff --git a/test/test_bq_generic_filters.py b/test/test_bq_generic_filters.py new file mode 100644 index 000000000..bdd21b224 --- /dev/null +++ b/test/test_bq_generic_filters.py @@ -0,0 +1,270 @@ +import dataclasses +import unittest +from datetime import datetime +from enum import Enum +from typing import Any + +from google.cloud import bigquery + +from db.python.tables.bq.generic_bq_filter import GenericBQFilter +from db.python.tables.bq.generic_bq_filter_model import GenericBQFilterModel + + +@dataclasses.dataclass(kw_only=True) +class GenericBQFilterTest(GenericBQFilterModel): + """Test model for GenericBQFilter""" + + test_string: GenericBQFilter[str] | None = None + test_int: GenericBQFilter[int] | None = None + test_float: GenericBQFilter[float] | None = None + test_dt: GenericBQFilter[datetime] | None = None + test_dict: dict[str, GenericBQFilter[str]] | None = None + test_enum: GenericBQFilter[Enum] | None = None + test_any: Any | None = None + + +class BGFilterTestEnum(str, Enum): + """Simple Enum classs""" + + ID = 'id' + VALUE = 'value' + + +class TestGenericBQFilters(unittest.TestCase): + """Test generic filters SQL generation""" + + def test_basic_no_override(self): + """Test that the basic filter converts to SQL as expected""" + filter_ = GenericBQFilterTest(test_string=GenericBQFilter(eq='test')) + sql, values = filter_.to_sql() + + self.assertEqual('test_string = @test_string_eq', sql) + self.assertDictEqual( + { + 'test_string_eq': bigquery.ScalarQueryParameter( + 'test_string_eq', 'STRING', 'test' + ) + }, + values, + ) + + def test_basic_override(self): + """Test that the basic filter with an override converts to SQL as expected""" + filter_ = GenericBQFilterTest(test_string=GenericBQFilter(eq='test')) + sql, values = filter_.to_sql({'test_string': 't.string'}) + + self.assertEqual('t.string = @t_string_eq', sql) + self.assertDictEqual( + { + 't_string_eq': bigquery.ScalarQueryParameter( + 't_string_eq', 'STRING', 'test' + ) + }, + values, + ) + + def test_single_string(self): + """ + Test that a single value filtered using the "in" operator + gets converted to an eq operation + """ + filter_ = GenericBQFilterTest(test_string=GenericBQFilter(in_=['test'])) + sql, values = filter_.to_sql() + + self.assertEqual('test_string = @test_string_in_eq', sql) + self.assertDictEqual( + { + 'test_string_in_eq': bigquery.ScalarQueryParameter( + 'test_string_in_eq', 'STRING', 'test' + ) + }, + values, + ) + + def test_single_int(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = 123 + filter_ = GenericBQFilterTest(test_int=GenericBQFilter(gt=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_int > @test_int_gt', sql) + self.assertDictEqual( + { + 'test_int_gt': bigquery.ScalarQueryParameter( + 'test_int_gt', 'INT64', value + ) + }, + values, + ) + + def test_single_float(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = 123.456 + filter_ = GenericBQFilterTest(test_float=GenericBQFilter(gte=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_float >= @test_float_gte', sql) + self.assertDictEqual( + { + 'test_float_gte': bigquery.ScalarQueryParameter( + 'test_float_gte', 'FLOAT64', value + ) + }, + values, + ) + + def test_single_datetime(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + datetime_str = '2021-10-08 01:02:03' + value = datetime.strptime(datetime_str, '%Y-%m-%d %H:%M:%S') + filter_ = GenericBQFilterTest(test_dt=GenericBQFilter(lt=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_dt < TIMESTAMP(@test_dt_lt)', sql) + self.assertDictEqual( + { + 'test_dt_lt': bigquery.ScalarQueryParameter( + 'test_dt_lt', 'STRING', datetime_str + ) + }, + values, + ) + + def test_single_enum(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = BGFilterTestEnum.ID + filter_ = GenericBQFilterTest(test_enum=GenericBQFilter(lte=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_enum <= @test_enum_lte', sql) + self.assertDictEqual( + { + 'test_enum_lte': bigquery.ScalarQueryParameter( + 'test_enum_lte', 'STRING', value.value + ) + }, + values, + ) + + def test_in_multiple_int(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = [1, 2] + filter_ = GenericBQFilterTest(test_int=GenericBQFilter(in_=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_int IN UNNEST(@test_int_in)', sql) + self.assertDictEqual( + { + 'test_int_in': bigquery.ArrayQueryParameter( + 'test_int_in', 'INT64', value + ) + }, + values, + ) + + def test_in_multiple_float(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = [1.0, 2.0] + filter_ = GenericBQFilterTest(test_float=GenericBQFilter(in_=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_float IN UNNEST(@test_float_in)', sql) + self.assertDictEqual( + { + 'test_float_in': bigquery.ArrayQueryParameter( + 'test_float_in', 'FLOAT64', value + ) + }, + values, + ) + + def test_in_multiple_str(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = ['A', 'B'] + filter_ = GenericBQFilterTest(test_string=GenericBQFilter(in_=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_string IN UNNEST(@test_string_in)', sql) + self.assertDictEqual( + { + 'test_string_in': bigquery.ArrayQueryParameter( + 'test_string_in', 'STRING', value + ) + }, + values, + ) + + def test_nin_multiple_str(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = ['A', 'B'] + filter_ = GenericBQFilterTest(test_string=GenericBQFilter(nin=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_string NOT IN UNNEST(@test_string_nin)', sql) + self.assertDictEqual( + { + 'test_string_nin': bigquery.ArrayQueryParameter( + 'test_string_nin', 'STRING', value + ) + }, + values, + ) + + def test_in_and_eq_multiple_str(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = ['A'] + filter_ = GenericBQFilterTest(test_string=GenericBQFilter(in_=value, eq='B')) + sql, values = filter_.to_sql() + + self.assertEqual( + 'test_string = @test_string_eq AND test_string = @test_string_in_eq', + sql, + ) + self.assertDictEqual( + { + 'test_string_eq': bigquery.ScalarQueryParameter( + 'test_string_eq', 'STRING', 'B' + ), + 'test_string_in_eq': bigquery.ScalarQueryParameter( + 'test_string_in_eq', 'STRING', 'A' + ), + }, + values, + ) + + def test_fail_none_in_tuple(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = (None,) + + # check if ValueError is raised + with self.assertRaises(ValueError) as context: + filter_ = GenericBQFilterTest(test_any=value) + filter_.to_sql() + + self.assertTrue( + 'There is very likely a trailing comma on the end of ' + 'GenericBQFilterTest.test_any. ' + 'If you actually want a tuple of length one with the value = (None,), ' + 'then use dataclasses.field(default_factory=lambda: (None,))' + in str(context.exception) + ) diff --git a/test/test_generic_filters.py b/test/test_generic_filters.py index 2c1348076..b5598be54 100644 --- a/test/test_generic_filters.py +++ b/test/test_generic_filters.py @@ -53,3 +53,54 @@ def test_in_multiple(self): self.assertEqual('test_int IN :test_int_in', sql) self.assertDictEqual({'test_int_in': value}, values) + + def test_gt_single(self): + """ + Test that a single value filtered using the "gt" operator + """ + filter_ = GenericFilterTest(test_int=GenericFilter(gt=123)) + sql, values = filter_.to_sql() + + self.assertEqual('test_int > :test_int_gt', sql) + self.assertDictEqual({'test_int_gt': 123}, values) + + def test_gte_single(self): + """ + Test that a single value filtered using the "gte" operator + """ + filter_ = GenericFilterTest(test_int=GenericFilter(gte=123)) + sql, values = filter_.to_sql() + + self.assertEqual('test_int >= :test_int_gte', sql) + self.assertDictEqual({'test_int_gte': 123}, values) + + def test_lt_single(self): + """ + Test that a single value filtered using the "lt" operator + """ + filter_ = GenericFilterTest(test_int=GenericFilter(lt=123)) + sql, values = filter_.to_sql() + + self.assertEqual('test_int < :test_int_lt', sql) + self.assertDictEqual({'test_int_lt': 123}, values) + + def test_lte_single(self): + """ + Test that a single value filtered using the "lte" operator + """ + filter_ = GenericFilterTest(test_int=GenericFilter(lte=123)) + sql, values = filter_.to_sql() + + self.assertEqual('test_int <= :test_int_lte', sql) + self.assertDictEqual({'test_int_lte': 123}, values) + + def test_not_in_multiple(self): + """ + Test that values filtered using the "nin" operator convert as expected + """ + value = [1, 2] + filter_ = GenericFilterTest(test_int=GenericFilter(nin=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_int NOT IN :test_int_nin', sql) + self.assertDictEqual({'test_int_nin': value}, values)