Skip to content

Commit

Permalink
GraphQL Billing funcs first idea.
Browse files Browse the repository at this point in the history
  • Loading branch information
milo-hyben committed Jan 31, 2024
1 parent fdcc325 commit 0318d5e
Showing 1 changed file with 185 additions and 7 deletions.
192 changes: 185 additions & 7 deletions api/graphql/schema.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# type: ignore
# flake8: noqa
# pylint: disable=no-value-for-parameter,redefined-builtin,missing-function-docstring,unused-argument
# pylint: disable=no-value-for-parameter,redefined-builtin,missing-function-docstring,unused-argument,too-many-lines
"""
Schema for GraphQL.
Note, we silence a lot of linting here because GraphQL looks at type annotations
and defaults to decide the GraphQL schema, so it might not necessarily look correct.
"""
import datetime
from collections import Counter
from inspect import isclass

import strawberry
Expand Down Expand Up @@ -35,6 +36,7 @@
AssayInternal,
AuditLogInternal,
BillingColumn,
BillingHailBatchCostRecord,
BillingInternal,
BillingTotalCostQueryModel,
FamilyInternal,
Expand Down Expand Up @@ -71,6 +73,23 @@ async def m(info: Info) -> list[str]:
GraphQLEnum = strawberry.type(type('GraphQLEnum', (object,), enum_methods))


def to_camel_case(test_str: str) -> str:
# using for loop to convert string to camel case
result = ''
capitalize_next = False
for char in test_str:
if char == '_':
capitalize_next = True
else:
if capitalize_next:
result += char.upper()
capitalize_next = False
else:
result += char

return result


@strawberry.type
class GraphQLProject:
"""Project GraphQL model"""
Expand Down Expand Up @@ -624,6 +643,106 @@ def from_internal(internal: BillingInternal) -> 'GraphQLBilling':
)


@strawberry.type
class GraphQLBatchCostRecord:
"""GraphQL Billing"""

id: str | None
ar_guid: str | None
batch_id: str | None
job_id: str | None
day: datetime.date | None

topic: str | None
namespace: str | None
name: str | None

sku: str | None
cost: float | None
url: str | None

@staticmethod
def from_json(json: dict) -> 'GraphQLBatchCostRecord':
return GraphQLBatchCostRecord(
id=json.get('id'),
ar_guid=json.get('ar_guid'),
batch_id=json.get('batch_id'),
job_id=json.get('job_id'),
day=json.get('day'),
topic=json.get('topic'),
namespace=json.get('namespace'),
name=json.get('name'),
sku=json.get('sku'),
cost=json.get('cost'),
url=json.get('url'),
)

@staticmethod
def from_internal(
internal: BillingHailBatchCostRecord, fields: list[str] | None = None
) -> list['GraphQLBatchCostRecord']:
"""
TODO sum the costs based on selected fields
"""
results = []
if not internal:
return results

ar_guid = internal.ar_guid

if fields is None:
for rec in internal.costs:
results.append(
GraphQLBatchCostRecord(
id=rec.get('id'),
ar_guid=ar_guid,
batch_id=rec.get('batch_id'),
job_id=rec.get('job_id'),
day=rec.get('day'),
topic=rec.get('topic'),
namespace=rec.get('namespace'),
name=rec.get('batch_name'),
sku=rec.get('batch_resource'),
cost=rec.get('cost'),
url=rec.get('url'),
)
)
else:
# we need to aggregate sum(cost) by fields
# if cost not present, then do distinct like operation?

# prepare the fields
aggregated = Counter()

class_fields = list(
GraphQLBatchCostRecord.__dict__['__annotations__'].keys()
)
for rec in internal.costs:
# create key based on fields
key = '_'.join(
[
str(rec.get(f))
if 'cost' not in f and to_camel_case(f) in fields
else ''
for f in class_fields
]
)
aggregated[key] += rec.get('cost', 0)

for key, cost in aggregated.items():
# split the key back to fields
fields = key.split('_')
# map to class_fields
record = {}
for pos in range(len(class_fields)):
record[class_fields[pos]] = fields[pos]

record['cost'] = cost
results.append(GraphQLBatchCostRecord.from_json(record))

return results


@strawberry.type
class Query:
"""GraphQL Queries"""
Expand Down Expand Up @@ -762,13 +881,75 @@ async def my_projects(self, info: Info) -> list[GraphQLProject]:
)
return [GraphQLProject.from_internal(p) for p in projects]

"""
"""
TODO split inot 4 or 5 different functions
e.g. billing_by_batch_id, billing_by_ar_guid, billing_by_topic, billing_by_gcp_project
"""

@staticmethod
def get_billing_layer(info: Info) -> BillingLayer:
# TODO is there a better way to get the BQ connection?
connection = info.context['connection']
bg_connection = BqConnection(connection.author)
return BillingLayer(bg_connection)

@staticmethod
async def extract_fields(info: Info) -> list[str]:
from graphql.parser import GraphQLParser

parser = GraphQLParser()
body = await info.context.get('request').json()
ast = parser.parse(body['query'])
fields = [f.name for f in ast.definitions[0].selections[-1].selections]
print('fields', fields)
return fields

@strawberry.field
async def billing_by_batch_id(
self,
info: Info,
batch_id: str,
) -> list[GraphQLBatchCostRecord]:
slayer = Query.get_billing_layer(info)
result = await slayer.get_cost_by_batch_id(batch_id)
fields = await Query.extract_fields(info)
return GraphQLBatchCostRecord.from_internal(result, fields)

@strawberry.field
async def billing_by_ar_guid(
self,
info: Info,
ar_guid: str,
) -> list[GraphQLBatchCostRecord]:
slayer = Query.get_billing_layer(info)
result = await slayer.get_cost_by_ar_guid(ar_guid)
fields = await Query.extract_fields(info)
return GraphQLBatchCostRecord.from_internal(result, fields)

@strawberry.field
async def billing_by_topic(
self,
info: Info,
topic: str | None = None,
day: GraphQLFilter[datetime.datetime] | None = None,
cost: GraphQLFilter[float] | None = None,
) -> list[GraphQLBilling]:
# slayer = Query.get_billing_layer(info)
return []

@strawberry.field
async def billing_by_gcp_project(
self,
info: Info,
gcp_project: str | None = None,
day: GraphQLFilter[datetime.datetime] | None = None,
cost: GraphQLFilter[float] | None = None,
) -> list[GraphQLBilling]:
# slayer = Query.get_billing_layer(info)
return []

@strawberry.field
async def billing(
async def billing_todel(
self,
info: Info,
batch_id: str | None = None,
Expand All @@ -785,10 +966,7 @@ async def billing(
# if not is_billing_enabled():
# raise ValueError('Billing is not enabled')

# TODO is there a better way to get the BQ connection?
connection = info.context['connection']
bg_connection = BqConnection(connection.author)
slayer = BillingLayer(bg_connection)
slayer = get_billing_layer(info)

if ar_guid:
res = await slayer.get_cost_by_ar_guid(ar_guid)
Expand Down

0 comments on commit 0318d5e

Please sign in to comment.