diff --git a/.bumpversion.cfg b/.bumpversion.cfg index c66ac522c..9ae63ecc9 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 6.5.0 +current_version = 6.6.2 commit = True tag = False parse = (?P\d+)\.(?P\d+)\.(?P[A-z0-9-]+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b827b7f14..545ead223 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: check-yaml exclude: '\.*conda/.*' @@ -14,13 +14,13 @@ repos: - id: check-added-large-files - repo: https://github.com/igorshubovych/markdownlint-cli - rev: v0.33.0 + rev: v0.38.0 hooks: - id: markdownlint args: ["--config", ".markdownlint.json"] - repo: https://github.com/ambv/black - rev: 23.3.0 + rev: 23.12.1 hooks: - id: black args: [.] @@ -29,7 +29,7 @@ repos: exclude: ^metamist/ - repo: https://github.com/PyCQA/flake8 - rev: "6.0.0" + rev: "6.1.0" hooks: - id: flake8 additional_dependencies: [flake8-bugbear, flake8-quotes] @@ -45,7 +45,7 @@ repos: # mypy - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.5.1 + rev: v1.8.0 hooks: - id: mypy args: [ diff --git a/api/graphql/loaders.py b/api/graphql/loaders.py index 66c559330..c5ed6354e 100644 --- a/api/graphql/loaders.py +++ b/api/graphql/loaders.py @@ -10,10 +10,10 @@ from strawberry.dataloader import DataLoader from api.utils import get_projectless_db_connection, group_by -from db.python.connect import NotFoundError from db.python.layers import ( AnalysisLayer, AssayLayer, + AuditLogLayer, FamilyLayer, ParticipantLayer, SampleLayer, @@ -24,16 +24,18 @@ from db.python.tables.project import ProjectPermissionsTable from db.python.tables.sample import SampleFilter from db.python.tables.sequencing_group import SequencingGroupFilter -from db.python.utils import GenericFilter, ProjectId +from db.python.utils import GenericFilter, NotFoundError from models.models import ( AnalysisInternal, AssayInternal, FamilyInternal, ParticipantInternal, Project, + ProjectId, SampleInternal, SequencingGroupInternal, ) +from models.models.audit_log import AuditLogInternal class LoaderKeys(enum.Enum): @@ -44,6 +46,9 @@ class LoaderKeys(enum.Enum): PROJECTS_FOR_IDS = 'projects_for_id' + AUDIT_LOGS_BY_IDS = 'audit_logs_by_ids' + AUDIT_LOGS_BY_ANALYSIS_IDS = 'audit_logs_by_analysis_ids' + ANALYSES_FOR_SEQUENCING_GROUPS = 'analyses_for_sequencing_groups' ASSAYS_FOR_SAMPLES = 'sequences_for_samples' @@ -168,6 +173,31 @@ async def wrapped(query: list[dict[str, Any]]) -> list[Any]: return connected_data_loader_caller +@connected_data_loader(LoaderKeys.AUDIT_LOGS_BY_IDS) +async def load_audit_logs_by_ids( + audit_log_ids: list[int], connection +) -> list[AuditLogInternal | None]: + """ + DataLoader: get_audit_logs_by_ids + """ + alayer = AuditLogLayer(connection) + logs = await alayer.get_for_ids(audit_log_ids) + logs_by_id = {log.id: log for log in logs} + return [logs_by_id.get(a) for a in audit_log_ids] + + +@connected_data_loader(LoaderKeys.AUDIT_LOGS_BY_ANALYSIS_IDS) +async def load_audit_logs_by_analysis_ids( + analysis_ids: list[int], connection +) -> list[list[AuditLogInternal]]: + """ + DataLoader: get_audit_logs_by_analysis_ids + """ + alayer = AnalysisLayer(connection) + logs = await alayer.get_audit_logs_by_analysis_ids(analysis_ids) + return [logs.get(a) or [] for a in analysis_ids] + + @connected_data_loader_with_params(LoaderKeys.ASSAYS_FOR_SAMPLES, default_factory=list) async def 
load_assays_by_samples( connection, ids, filter: AssayFilter @@ -332,15 +362,13 @@ async def load_projects_for_ids(project_ids: list[int], connection) -> list[Proj """ Get projects by IDs """ - pttable = ProjectPermissionsTable(connection.connection) - - ids = [int(p) for p in project_ids] + pttable = ProjectPermissionsTable(connection) projects = await pttable.get_and_check_access_to_projects_for_ids( - user=connection.author, project_ids=ids, readonly=True + user=connection.author, project_ids=project_ids, readonly=True ) p_by_id = {p.id: p for p in projects} - projects = [p_by_id.get(p) for p in ids] + projects = [p_by_id.get(p) for p in project_ids] return [p for p in projects if p is not None] diff --git a/api/graphql/schema.py b/api/graphql/schema.py index 8c040d316..ad7236220 100644 --- a/api/graphql/schema.py +++ b/api/graphql/schema.py @@ -39,6 +39,7 @@ AnalysisInternal, AssayInternal, Cohort, + AuditLogInternal, FamilyInternal, ParticipantInternal, Project, @@ -278,6 +279,29 @@ async def cohort( return [GraphQLCohort.from_internal(c) for c in cohorts] +@strawberry.type +class GraphQLAuditLog: + """AuditLog GraphQL model""" + + id: int + author: str + timestamp: datetime.datetime + ar_guid: str | None + comment: str | None + meta: strawberry.scalars.JSON + + @staticmethod + def from_internal(audit_log: AuditLogInternal) -> 'GraphQLAuditLog': + return GraphQLAuditLog( + id=audit_log.id, + author=audit_log.author, + timestamp=audit_log.timestamp, + ar_guid=audit_log.ar_guid, + comment=audit_log.comment, + meta=audit_log.meta, + ) + + @strawberry.type class GraphQLAnalysis: """Analysis GraphQL model""" @@ -289,7 +313,6 @@ class GraphQLAnalysis: timestamp_completed: datetime.datetime | None = None active: bool meta: strawberry.scalars.JSON - author: str @staticmethod def from_internal(internal: AnalysisInternal) -> 'GraphQLAnalysis': @@ -301,7 +324,6 @@ def from_internal(internal: AnalysisInternal) -> 'GraphQLAnalysis': timestamp_completed=internal.timestamp_completed, active=internal.active, meta=internal.meta, - author=internal.author, ) @strawberry.field @@ -318,6 +340,14 @@ async def project(self, info: Info, root: 'GraphQLAnalysis') -> GraphQLProject: project = await loader.load(root.project) return GraphQLProject.from_internal(project) + @strawberry.field + async def audit_logs( + self, info: Info, root: 'GraphQLAnalysis' + ) -> list[GraphQLAuditLog]: + loader = info.context[LoaderKeys.AUDIT_LOGS_BY_ANALYSIS_IDS] + audit_logs = await loader.load(root.id) + return [GraphQLAuditLog.from_internal(audit_log) for audit_log in audit_logs] + @strawberry.type class GraphQLFamily: @@ -371,6 +401,7 @@ class GraphQLParticipant: karyotype: str | None project_id: strawberry.Private[int] + audit_log_id: strawberry.Private[int | None] @staticmethod def from_internal(internal: ParticipantInternal) -> 'GraphQLParticipant': @@ -382,6 +413,7 @@ def from_internal(internal: ParticipantInternal) -> 'GraphQLParticipant': reported_gender=internal.reported_gender, karyotype=internal.karyotype, project_id=internal.project, + audit_log_id=internal.audit_log_id, ) @strawberry.field @@ -423,6 +455,16 @@ async def project(self, info: Info, root: 'GraphQLParticipant') -> GraphQLProjec project = await loader.load(root.project_id) return GraphQLProject.from_internal(project) + @strawberry.field + async def audit_log( + self, info: Info, root: 'GraphQLParticipant' + ) -> GraphQLAuditLog | None: + if root.audit_log_id is None: + return None + loader = info.context[LoaderKeys.AUDIT_LOGS_BY_IDS] + audit_log = await 
loader.load(root.audit_log_id) + return GraphQLAuditLog.from_internal(audit_log) + @strawberry.type class GraphQLSample: @@ -433,7 +475,6 @@ class GraphQLSample: active: bool meta: strawberry.scalars.JSON type: str - author: str | None # keep as integers, because they're useful to reference in the fields below internal_id: strawberry.Private[int] @@ -441,14 +482,13 @@ class GraphQLSample: project_id: strawberry.Private[int] @staticmethod - def from_internal(sample: SampleInternal): + def from_internal(sample: SampleInternal) -> 'GraphQLSample': return GraphQLSample( id=sample_id_format(sample.id), external_id=sample.external_id, active=sample.active, meta=sample.meta, type=sample.type, - author=sample.author, # internals internal_id=sample.id, participant_id=sample.participant_id, @@ -533,21 +573,20 @@ class GraphQLSequencingGroup: technology: str platform: str meta: strawberry.scalars.JSON - external_ids: strawberry.scalars.JSON | None + external_ids: strawberry.scalars.JSON internal_id: strawberry.Private[int] sample_id: strawberry.Private[int] @staticmethod def from_internal(internal: SequencingGroupInternal) -> 'GraphQLSequencingGroup': - # print(internal) return GraphQLSequencingGroup( id=sequencing_group_id_format(internal.id), type=internal.type, technology=internal.technology, platform=internal.platform, meta=internal.meta, - external_ids=internal.external_ids, + external_ids=internal.external_ids or {}, # internal internal_id=internal.id, sample_id=internal.sample_id, @@ -574,7 +613,7 @@ async def analyses( loader = info.context[LoaderKeys.ANALYSES_FOR_SEQUENCING_GROUPS] project_id_map = {} if project: - ptable = ProjectPermissionsTable(connection.connection) + ptable = ProjectPermissionsTable(connection) project_ids = project.all_values() projects = await ptable.get_and_check_access_to_projects_for_names( user=connection.author, project_names=project_ids, readonly=True @@ -625,7 +664,7 @@ def from_internal(internal: AssayInternal) -> 'GraphQLAssay': id=internal.id, type=internal.type, meta=internal.meta, - external_ids=internal.external_ids, + external_ids=internal.external_ids or {}, # internal sample_id=internal.sample_id, ) @@ -685,7 +724,7 @@ async def cohort( @strawberry.field() async def project(self, info: Info, name: str) -> GraphQLProject: connection = info.context['connection'] - ptable = ProjectPermissionsTable(connection.connection) + ptable = ProjectPermissionsTable(connection) project = await ptable.get_and_check_access_to_project_for_name( user=connection.author, project_name=name, readonly=True ) @@ -704,7 +743,7 @@ async def sample( active: GraphQLFilter[bool] | None = None, ) -> list[GraphQLSample]: connection = info.context['connection'] - ptable = ProjectPermissionsTable(connection.connection) + ptable = ProjectPermissionsTable(connection) slayer = SampleLayer(connection) if not id and not project: @@ -757,7 +796,7 @@ async def sequencing_groups( ) -> list[GraphQLSequencingGroup]: connection = info.context['connection'] sglayer = SequencingGroupLayer(connection) - ptable = ProjectPermissionsTable(connection.connection) + ptable = ProjectPermissionsTable(connection) if not (project or sample_id or id): raise ValueError('Must filter by project, sample or id') @@ -823,7 +862,7 @@ async def family(self, info: Info, family_id: int) -> GraphQLFamily: @strawberry.field async def my_projects(self, info: Info) -> list[GraphQLProject]: connection = info.context['connection'] - ptable = ProjectPermissionsTable(connection.connection) + ptable = 
ProjectPermissionsTable(connection)
         projects = await ptable.get_projects_accessible_by_user(
             connection.author, readonly=True
         )
diff --git a/api/routes/analysis.py b/api/routes/analysis.py
index a4c8c6f95..6aaa86645 100644
--- a/api/routes/analysis.py
+++ b/api/routes/analysis.py
@@ -104,10 +104,12 @@ async def create_analysis(
 
     atable = AnalysisLayer(connection)
 
+    if analysis.author:
+        # special tracking here, if we can't catch it through the header
+        connection.on_behalf_of = analysis.author
+
     analysis_id = await atable.create_analysis(
         analysis.to_internal(),
-        # analysis-runner: usage is tracked through `on_behalf_of`
-        author=analysis.author,
     )
 
     return analysis_id
@@ -226,7 +228,7 @@ async def query_analyses(
     if not query.projects:
         raise ValueError('Must specify "projects"')
 
-    pt = ProjectPermissionsTable(connection=connection.connection)
+    pt = ProjectPermissionsTable(connection)
     projects = await pt.get_and_check_access_to_projects_for_names(
         user=connection.author, project_names=query.projects, readonly=True
     )
@@ -249,7 +251,7 @@ async def get_analysis_runner_log(
    atable = AnalysisLayer(connection)
    project_ids = None
    if project_names:
-        pt = ProjectPermissionsTable(connection=connection.connection)
+        pt = ProjectPermissionsTable(connection)
        project_ids = await pt.get_project_ids_from_names_and_user(
            connection.author, project_names, readonly=True
        )
@@ -335,7 +337,7 @@ async def get_proportionate_map(
        }
    }
    """
-    pt = ProjectPermissionsTable(connection=connection.connection)
+    pt = ProjectPermissionsTable(connection)
    project_ids = await pt.get_project_ids_from_names_and_user(
        connection.author, projects, readonly=True
    )
diff --git a/api/routes/assay.py b/api/routes/assay.py
index d0b386ccb..3770d5a76 100644
--- a/api/routes/assay.py
+++ b/api/routes/assay.py
@@ -79,7 +79,7 @@ async def get_assays_by_criteria(
 ):
     """Get assays by criteria"""
     assay_layer = AssayLayer(connection)
-    pt = ProjectPermissionsTable(connection.connection)
+    pt = ProjectPermissionsTable(connection)
 
     pids: list[int] | None = None
     if projects:
diff --git a/api/routes/billing.py b/api/routes/billing.py
index bdc0d8b52..7d93a599c 100644
--- a/api/routes/billing.py
+++ b/api/routes/billing.py
@@ -1,28 +1,48 @@
 """
 Billing routes
 """
-from fastapi import APIRouter
 from async_lru import alru_cache
+from fastapi import APIRouter
 
-from api.settings import BILLING_CACHE_RESPONSE_TTL
-from api.utils.db import (
-    BqConnection,
-    get_author,
-)
+from api.settings import BILLING_CACHE_RESPONSE_TTL, BQ_AGGREG_VIEW
+from api.utils.db import BqConnection, get_author
 from db.python.layers.billing import BillingLayer
-from models.models.billing import (
+from models.enums import BillingSource
+from models.models import (
     BillingColumn,
     BillingCostBudgetRecord,
-    BillingQueryModel,
-    BillingRowRecord,
-    BillingTotalCostRecord,
+    BillingHailBatchCostRecord,
     BillingTotalCostQueryModel,
+    BillingTotalCostRecord,
 )
 
-
 router = APIRouter(prefix='/billing', tags=['billing'])
 
 
+def _get_billing_layer_from(author: str) -> BillingLayer:
+    """
+    Initialise the billing layer for the given author
+    """
+    if not is_billing_enabled():
+        raise ValueError('Billing is not enabled')
+
+    connection = BqConnection(author)
+    billing_layer = BillingLayer(connection)
+    return billing_layer
+
+
+@router.get(
+    '/is-billing-enabled',
+    response_model=bool,
+    operation_id='isBillingEnabled',
+)
+def is_billing_enabled() -> bool:
+    """
+    Return True if billing is enabled, False otherwise
+    """
+    return BQ_AGGREG_VIEW is not None
+
+
 @router.get(
     '/gcp-projects',
     response_model=list[str],
     operation_id='getGcpProjects',
 )
 @alru_cache(ttl=BILLING_CACHE_RESPONSE_TTL)
@@
-33,8 +53,7 @@ async def get_gcp_projects( author: str = get_author, ) -> list[str]: """Get list of all GCP projects in database""" - connection = BqConnection(author) - billing_layer = BillingLayer(connection) + billing_layer = _get_billing_layer_from(author) records = await billing_layer.get_gcp_projects() return records @@ -49,8 +68,7 @@ async def get_topics( author: str = get_author, ) -> list[str]: """Get list of all topics in database""" - connection = BqConnection(author) - billing_layer = BillingLayer(connection) + billing_layer = _get_billing_layer_from(author) records = await billing_layer.get_topics() return records @@ -65,8 +83,7 @@ async def get_cost_categories( author: str = get_author, ) -> list[str]: """Get list of all service description / cost categories in database""" - connection = BqConnection(author) - billing_layer = BillingLayer(connection) + billing_layer = _get_billing_layer_from(author) records = await billing_layer.get_cost_categories() return records @@ -87,8 +104,7 @@ async def get_skus( There is over 400 Skus so limit is required Results are sorted ASC """ - connection = BqConnection(author) - billing_layer = BillingLayer(connection) + billing_layer = _get_billing_layer_from(author) records = await billing_layer.get_skus(limit, offset) return records @@ -106,8 +122,7 @@ async def get_datasets( Get list of all datasets in database Results are sorted ASC """ - connection = BqConnection(author) - billing_layer = BillingLayer(connection) + billing_layer = _get_billing_layer_from(author) records = await billing_layer.get_datasets() return records @@ -125,8 +140,7 @@ async def get_sequencing_types( Get list of all sequencing_types in database Results are sorted ASC """ - connection = BqConnection(author) - billing_layer = BillingLayer(connection) + billing_layer = _get_billing_layer_from(author) records = await billing_layer.get_sequencing_types() return records @@ -144,8 +158,7 @@ async def get_stages( Get list of all stages in database Results are sorted ASC """ - connection = BqConnection(author) - billing_layer = BillingLayer(connection) + billing_layer = _get_billing_layer_from(author) records = await billing_layer.get_stages() return records @@ -163,12 +176,65 @@ async def get_sequencing_groups( Get list of all sequencing_groups in database Results are sorted ASC """ - connection = BqConnection(author) - billing_layer = BillingLayer(connection) + billing_layer = _get_billing_layer_from(author) records = await billing_layer.get_sequencing_groups() return records +@router.get( + '/compute-categories', + response_model=list[str], + operation_id='getComputeCategories', +) +@alru_cache(ttl=BILLING_CACHE_RESPONSE_TTL) +async def get_compute_categories( + author: str = get_author, +) -> list[str]: + """ + Get list of all compute categories in database + Results are sorted ASC + """ + billing_layer = _get_billing_layer_from(author) + records = await billing_layer.get_compute_categories() + return records + + +@router.get( + '/cromwell-sub-workflow-names', + response_model=list[str], + operation_id='getCromwellSubWorkflowNames', +) +@alru_cache(ttl=BILLING_CACHE_RESPONSE_TTL) +async def get_cromwell_sub_workflow_names( + author: str = get_author, +) -> list[str]: + """ + Get list of all cromwell_sub_workflow_names in database + Results are sorted ASC + """ + billing_layer = _get_billing_layer_from(author) + records = await billing_layer.get_cromwell_sub_workflow_names() + return records + + +@router.get( + '/wdl-task-names', + response_model=list[str], + 
operation_id='getWdlTaskNames', +) +@alru_cache(ttl=BILLING_CACHE_RESPONSE_TTL) +async def get_wdl_task_names( + author: str = get_author, +) -> list[str]: + """ + Get list of all wdl_task_names in database + Results are sorted ASC + """ + billing_layer = _get_billing_layer_from(author) + records = await billing_layer.get_wdl_task_names() + return records + + @router.get( '/invoice-months', response_model=list[str], @@ -182,36 +248,58 @@ async def get_invoice_months( Get list of all invoice months in database Results are sorted DESC """ - connection = BqConnection(author) - billing_layer = BillingLayer(connection) + billing_layer = _get_billing_layer_from(author) records = await billing_layer.get_invoice_months() return records -@router.post( - '/query', response_model=list[BillingRowRecord], operation_id='queryBilling' +@router.get( + '/namespaces', + response_model=list[str], + operation_id='getNamespaces', ) -@alru_cache(maxsize=10, ttl=BILLING_CACHE_RESPONSE_TTL) -async def query_billing( - query: BillingQueryModel, - limit: int = 10, +@alru_cache(ttl=BILLING_CACHE_RESPONSE_TTL) +async def get_namespaces( author: str = get_author, -) -> list[BillingRowRecord]: +) -> list[str]: """ - Get Billing records by some criteria, date is required to minimize BQ cost + Get list of all namespaces in database + Results are sorted DESC + """ + billing_layer = _get_billing_layer_from(author) + records = await billing_layer.get_namespaces() + return records - E.g. - { - "topic": ["hail"], - "date": "2023-03-02", - "cost_category": ["Hail compute Credit"] - } +@router.get( + '/cost-by-ar-guid/{ar_guid}', + response_model=BillingHailBatchCostRecord, + operation_id='costByArGuid', +) +@alru_cache(maxsize=10, ttl=BILLING_CACHE_RESPONSE_TTL) +async def get_cost_by_ar_guid( + ar_guid: str, + author: str = get_author, +) -> BillingHailBatchCostRecord: + """Get Hail Batch costs by AR GUID""" + billing_layer = _get_billing_layer_from(author) + records = await billing_layer.get_cost_by_ar_guid(ar_guid) + return records - """ - connection = BqConnection(author) - billing_layer = BillingLayer(connection) - records = await billing_layer.query(query.to_filter(), limit) + +@router.get( + '/cost-by-batch-id/{batch_id}', + response_model=BillingHailBatchCostRecord, + operation_id='costByBatchId', +) +@alru_cache(maxsize=10, ttl=BILLING_CACHE_RESPONSE_TTL) +async def get_cost_by_batch_id( + batch_id: str, + author: str = get_author, +) -> BillingHailBatchCostRecord: + """Get Hail Batch costs by Batch ID""" + billing_layer = _get_billing_layer_from(author) + records = await billing_layer.get_cost_by_batch_id(batch_id) return records @@ -341,12 +429,87 @@ async def get_total_cost( "order_by": {"cost": true} } - """ + 12. Get total cost by compute_category order by cost DESC: - connection = BqConnection(author) - billing_layer = BillingLayer(connection) + { + "fields": ["compute_category"], + "start_date": "2023-11-10", + "end_date": "2023-11-10", + "order_by": {"cost": true} + } + + 13. Get total cost by cromwell_sub_workflow_name, order by cost DESC: + + { + "fields": ["cromwell_sub_workflow_name"], + "start_date": "2023-11-10", + "end_date": "2023-11-10", + "order_by": {"cost": true} + } + + 14. Get total cost by sku for given cromwell_workflow_id, order by cost DESC: + + { + "fields": ["sku"], + "start_date": "2023-11-10", + "end_date": "2023-11-10", + "filters": {"cromwell_workflow_id": "cromwell-00448f7b-8ef3-4d22-80ab-e302acdb2d28"}, + "order_by": {"cost": true} + } + + 15. 
Get total cost by sku for given goog_pipelines_worker, order by cost DESC: + + { + "fields": ["goog_pipelines_worker"], + "start_date": "2023-11-10", + "end_date": "2023-11-10", + "order_by": {"cost": true} + } + + 16. Get total cost by sku for given wdl_task_name, order by cost DESC: + + { + "fields": ["wdl_task_name"], + "start_date": "2023-11-10", + "end_date": "2023-11-10", + "order_by": {"cost": true} + } + + 17. Get total cost by sku for provided ID, which can be any of + [ar_guid, batch_id, sequencing_group or cromwell_workflow_id], + order by cost DESC: + + { + "fields": ["sku", "ar_guid", "batch_id", "sequencing_group", "cromwell_workflow_id"], + "start_date": "2023-11-01", + "end_date": "2023-11-30", + "filters": { + "ar_guid": "855a6153-033c-4398-8000-46ed74c02fe8", + "batch_id": "429518", + "sequencing_group": "cpg246751", + "cromwell_workflow_id": "cromwell-e252f430-4143-47ec-a9c0-5f7face1b296" + }, + "filters_op": "OR", + "order_by": {"cost": true} + } + + 18. Get weekly total cost by sku for selected cost_category, order by day ASC: + + { + "fields": ["sku"], + "start_date": "2022-11-01", + "end_date": "2023-12-07", + "filters": { + "cost_category": "Cloud Storage" + }, + "order_by": {"day": false}, + "time_periods": "week" + } + + """ + billing_layer = _get_billing_layer_from(author) records = await billing_layer.get_total_cost(query) - return records + return [BillingTotalCostRecord.from_json(record) for record in records] @router.get( @@ -358,20 +521,18 @@ async def get_total_cost( async def get_running_costs( field: BillingColumn, invoice_month: str | None = None, - source: str | None = None, + source: BillingSource | None = None, author: str = get_author, ) -> list[BillingCostBudgetRecord]: """ Get running cost for specified fields in database - e.g. fields = ['gcp_project', 'topic'] + e.g. fields = ['gcp_project', 'topic', 'wdl_task_names', 'cromwell_sub_workflow_name', 'compute_category'] """ - # TODO replace alru_cache with async-cache? # so we can skip author for caching? 
# pip install async-cache # @AsyncTTL(time_to_live=BILLING_CACHE_RESPONSE_TTL, maxsize=1024, skip_args=2) - connection = BqConnection(author) - billing_layer = BillingLayer(connection) + billing_layer = _get_billing_layer_from(author) records = await billing_layer.get_running_cost(field, invoice_month, source) return records diff --git a/api/routes/project.py b/api/routes/project.py index 5771dfc91..559938485 100644 --- a/api/routes/project.py +++ b/api/routes/project.py @@ -12,14 +12,14 @@ @router.get('/all', operation_id='getAllProjects', response_model=List[Project]) async def get_all_projects(connection=get_projectless_db_connection): """Get list of projects""" - ptable = ProjectPermissionsTable(connection.connection) + ptable = ProjectPermissionsTable(connection) return await ptable.get_all_projects(author=connection.author) @router.get('/', operation_id='getMyProjects', response_model=List[str]) async def get_my_projects(connection=get_projectless_db_connection): """Get projects I have access to""" - ptable = ProjectPermissionsTable(connection.connection) + ptable = ProjectPermissionsTable(connection) projects = await ptable.get_projects_accessible_by_user( author=connection.author, readonly=True ) @@ -36,7 +36,7 @@ async def create_project( """ Create a new project """ - ptable = ProjectPermissionsTable(connection.connection) + ptable = ProjectPermissionsTable(connection) pid = await ptable.create_project( project_name=name, dataset_name=dataset, @@ -56,7 +56,7 @@ async def create_project( @router.get('/seqr/all', operation_id='getSeqrProjects') async def get_seqr_projects(connection: Connection = get_projectless_db_connection): """Get SM projects that should sync to seqr""" - ptable = ProjectPermissionsTable(connection.connection) + ptable = ProjectPermissionsTable(connection) return await ptable.get_seqr_projects() @@ -67,7 +67,7 @@ async def update_project( connection: Connection = get_projectless_db_connection, ): """Update a project by project name""" - ptable = ProjectPermissionsTable(connection.connection) + ptable = ProjectPermissionsTable(connection) return await ptable.update_project( project_name=project, update=project_update_model, author=connection.author ) @@ -84,7 +84,7 @@ async def delete_project_data( Can optionally delete the project itself. Requires READ access + project-creator permissions """ - ptable = ProjectPermissionsTable(connection.connection) + ptable = ProjectPermissionsTable(connection) p_obj = await ptable.get_and_check_access_to_project_for_name( user=connection.author, project_name=project, readonly=False ) @@ -106,7 +106,7 @@ async def update_project_members( Update project members for specific read / write group. 
Not that this is protected by access to a specific access group """ - ptable = ProjectPermissionsTable(connection.connection) + ptable = ProjectPermissionsTable(connection) await ptable.set_group_members( group_name=ptable.get_project_group_name(project, readonly=readonly), members=members, diff --git a/api/routes/sample.py b/api/routes/sample.py index 1b4bf76d9..646a2dc7b 100644 --- a/api/routes/sample.py +++ b/api/routes/sample.py @@ -130,7 +130,7 @@ async def get_samples( """ st = SampleLayer(connection) - pt = ProjectPermissionsTable(connection.connection) + pt = ProjectPermissionsTable(connection) pids: list[int] | None = None if project_ids: pids = await pt.get_project_ids_from_names_and_user( diff --git a/api/routes/web.py b/api/routes/web.py index 38255514b..789e174d0 100644 --- a/api/routes/web.py +++ b/api/routes/web.py @@ -77,7 +77,7 @@ async def search_by_keyword(keyword: str, connection=get_projectless_db_connecti that you are a part of (automatically). """ # raise ValueError("Test") - pt = ProjectPermissionsTable(connection.connection) + pt = ProjectPermissionsTable(connection) projects = await pt.get_projects_accessible_by_user( connection.author, readonly=True ) diff --git a/api/server.py b/api/server.py index 62325807e..652aa4322 100644 --- a/api/server.py +++ b/api/server.py @@ -2,25 +2,24 @@ import time import traceback -from fastapi import FastAPI, Request, HTTPException, APIRouter +from fastapi import APIRouter, FastAPI, HTTPException, Request +from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from fastapi.staticfiles import StaticFiles -from fastapi.middleware.cors import CORSMiddleware from pydantic import ValidationError from starlette.responses import FileResponse -from db.python.connect import SMConnections -from db.python.tables.project import is_all_access -from db.python.utils import get_logger - from api import routes -from api.utils import get_openapi_schema_func -from api.utils.exceptions import determine_code_from_error from api.graphql.schema import MetamistGraphQLRouter # type: ignore from api.settings import PROFILE_REQUESTS, SKIP_DATABASE_CONNECTION +from api.utils import get_openapi_schema_func +from api.utils.exceptions import determine_code_from_error +from db.python.connect import SMConnections +from db.python.tables.project import is_all_access +from db.python.utils import get_logger # This tag is automatically updated by bump2version -_VERSION = '6.5.0' +_VERSION = '6.6.2' logger = get_logger() @@ -141,11 +140,11 @@ async def exception_handler(request: Request, e: Exception): cors_middleware = middlewares[0] request_origin = request.headers.get('origin', '') - if cors_middleware and '*' in cors_middleware.options['allow_origins']: + if cors_middleware and '*' in cors_middleware.options['allow_origins']: # type: ignore response.headers['Access-Control-Allow-Origin'] = '*' elif ( cors_middleware - and request_origin in cors_middleware.options['allow_origins'] + and request_origin in cors_middleware.options['allow_origins'] # type: ignore ): response.headers['Access-Control-Allow-Origin'] = request_origin @@ -169,9 +168,10 @@ async def exception_handler(request: Request, e: Exception): if __name__ == '__main__': - import uvicorn import logging + import uvicorn + logging.getLogger('watchfiles').setLevel(logging.WARNING) logging.getLogger('watchfiles.main').setLevel(logging.WARNING) diff --git a/api/settings.py b/api/settings.py index ab92c717c..98d11062f 100644 --- a/api/settings.py +++ b/api/settings.py @@ 
-42,10 +42,11 @@ BQ_AGGREG_EXT_VIEW = os.getenv('SM_GCP_BQ_AGGREG_EXT_VIEW') BQ_BUDGET_VIEW = os.getenv('SM_GCP_BQ_BUDGET_VIEW') BQ_GCP_BILLING_VIEW = os.getenv('SM_GCP_BQ_BILLING_VIEW') +BQ_BATCHES_VIEW = os.getenv('SM_GCP_BQ_BATCHES_VIEW') # This is to optimise BQ queries, DEV table has data only for Mar 2023 -BQ_DAYS_BACK_OPTIMAL = 30 # Look back 30 days for optimal query -BILLING_CACHE_RESPONSE_TTL = 3600 # 1 Hour +BQ_DAYS_BACK_OPTIMAL = 30 # Look back 30 days for optimal query +BILLING_CACHE_RESPONSE_TTL = 3600 # 1 Hour def get_default_user() -> str | None: diff --git a/api/utils/db.py b/api/utils/db.py index 7a22944fc..5d88a2b1b 100644 --- a/api/utils/db.py +++ b/api/utils/db.py @@ -1,22 +1,21 @@ -from os import getenv import logging -from typing import Optional +from os import getenv from fastapi import Depends, HTTPException, Request -from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from google.auth.transport import requests from google.oauth2 import id_token from api.settings import get_default_user from api.utils.gcp import email_from_id_token -from db.python.connect import SMConnections, Connection +from db.python.connect import Connection, SMConnections from db.python.gcp_connect import BqConnection, PubSubConnection - +from db.python.tables.project import ProjectPermissionsTable EXPECTED_AUDIENCE = getenv('SM_OAUTHAUDIENCE') -def get_jwt_from_request(request: Request) -> Optional[str]: +def get_jwt_from_request(request: Request) -> str | None: """ Get google JWT value, capture it like this instead of using x_goog_iap_jwt_assertion = Header(None) @@ -25,11 +24,22 @@ def get_jwt_from_request(request: Request) -> Optional[str]: return request.headers.get('x-goog-iap-jwt-assertion') +def get_ar_guid(request: Request) -> str | None: + """Get sm-ar-guid from the headers to provide with requests""" + return request.headers.get('sm-ar-guid') + + +def get_on_behalf_of(request: Request) -> str | None: + """ + Get sm-on-behalf-of if there are requests that were performed on behalf of + someone else (some automated process) + """ + return request.headers.get('sm-on-behalf-of') + + def authenticate( - token: Optional[HTTPAuthorizationCredentials] = Depends( - HTTPBearer(auto_error=False) - ), - x_goog_iap_jwt_assertion: Optional[str] = Depends(get_jwt_from_request), + token: HTTPAuthorizationCredentials | None = Depends(HTTPBearer(auto_error=False)), + x_goog_iap_jwt_assertion: str | None = Depends(get_jwt_from_request), ) -> str: """ If a token (OR Google IAP auth jwt) is provided, @@ -51,30 +61,58 @@ def authenticate( logging.info(f'Using {default_user} as authenticated user') return default_user - raise HTTPException(status_code=401, detail=f'Not authenticated :(') + raise HTTPException(status_code=401, detail='Not authenticated :(') async def dependable_get_write_project_connection( - project: str, author: str = Depends(authenticate) + project: str, + request: Request, + author: str = Depends(authenticate), + ar_guid: str = Depends(get_ar_guid), + on_behalf_of: str | None = Depends(get_on_behalf_of), ) -> Connection: """FastAPI handler for getting connection WITH project""" - return await SMConnections.get_connection( - project_name=project, author=author, readonly=False + meta = {'path': request.url.path} + if request.client: + meta['ip'] = request.client.host + return await ProjectPermissionsTable.get_project_connection( + project_name=project, + author=author, + readonly=False, + 
ar_guid=ar_guid, + on_behalf_of=on_behalf_of, + meta=meta, ) async def dependable_get_readonly_project_connection( - project: str, author: str = Depends(authenticate) + project: str, + author: str = Depends(authenticate), + ar_guid: str = Depends(get_ar_guid), ) -> Connection: """FastAPI handler for getting connection WITH project""" - return await SMConnections.get_connection( - project_name=project, author=author, readonly=True + return await ProjectPermissionsTable.get_project_connection( + project_name=project, + author=author, + readonly=True, + on_behalf_of=None, + ar_guid=ar_guid, ) -async def dependable_get_connection(author: str = Depends(authenticate)): +async def dependable_get_connection( + request: Request, + author: str = Depends(authenticate), + ar_guid: str = Depends(get_ar_guid), +): """FastAPI handler for getting connection withOUT project""" - return await SMConnections.get_connection_no_project(author) + meta = {'path': request.url.path} + if request.client: + meta['ip'] = request.client.host + + return await SMConnections.get_connection_no_project( + author, ar_guid=ar_guid, meta=meta + ) async def dependable_get_bq_connection(author: str = Depends(authenticate)): diff --git a/api/utils/exceptions.py b/api/utils/exceptions.py index c418ef6cd..ff24b6f6e 100644 --- a/api/utils/exceptions.py +++ b/api/utils/exceptions.py @@ -1,5 +1,4 @@ -from db.python.connect import NotFoundError -from db.python.utils import Forbidden +from db.python.utils import Forbidden, NotFoundError def determine_code_from_error(e): diff --git a/db/project.xml b/db/project.xml index dde025226..d7678e853 100644 --- a/db/project.xml +++ b/db/project.xml @@ -862,4 +862,268 @@ + + SET @@system_versioning_alter_history = 1; + + + + + + + + + + + + + + + + + + + + + + + + + ALTER TABLE `audit_log` ADD SYSTEM VERSIONING; + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + ALTER TABLE analysis CHANGE author author VARCHAR(255) NULL; + ALTER TABLE assay CHANGE author author VARCHAR(255) NULL; + ALTER TABLE assay_external_id CHANGE author author VARCHAR(255) NULL; + ALTER TABLE cohort CHANGE author author VARCHAR(255) NULL; + ALTER TABLE family CHANGE author author VARCHAR(255) NULL; + ALTER TABLE family_participant CHANGE author author VARCHAR(255) NULL; + ALTER TABLE group_member CHANGE author author VARCHAR(255) NULL DEFAULT NULL; + ALTER TABLE participant CHANGE author author VARCHAR(255) NULL; + ALTER TABLE participant_phenotypes CHANGE author author VARCHAR(255) NULL; + ALTER TABLE project CHANGE author author VARCHAR(255) NULL; + ALTER TABLE sample CHANGE author author VARCHAR(255) NULL; + ALTER TABLE sample_sequencing CHANGE author author VARCHAR(255) NULL; + ALTER TABLE sample_sequencing_eid CHANGE author author VARCHAR(255) NULL; + ALTER TABLE sequencing_group CHANGE author author VARCHAR(255) NULL; + ALTER TABLE sequencing_group_assay CHANGE author author VARCHAR(255) NULL; + ALTER TABLE sequencing_group_external_id CHANGE author author VARCHAR(255) NULL; + diff --git a/db/python/connect.py b/db/python/connect.py index e27e74293..971058f8f 100644 --- a/db/python/connect.py +++ b/db/python/connect.py @@ -1,19 +1,18 @@ -# pylint: disable=unused-import +# pylint: disable=unused-import,too-many-instance-attributes # flake8: noqa """ Code for connecting to Postgres database """ import 
abc +import asyncio import json import logging import os -from typing import Optional import databases from api.settings import LOG_DATABASE_QUERIES -from db.python.tables.project import ProjectPermissionsTable -from db.python.utils import InternalError, NoOpAenter, NotFoundError +from db.python.utils import InternalError logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) @@ -28,9 +27,9 @@ 'assay', 'sequencing_group_assay', 'analysis_sequencing_group', + 'analysis_sample', 'assay_external_id', 'sequencing_group_external_id', - 'analysis_sequencing_group', 'family', 'family_participant', 'participant_phenotypes', @@ -44,12 +43,48 @@ class Connection: def __init__( self, connection: databases.Database, - project: Optional[int], + project: int | None, author: str, + on_behalf_of: str | None, + readonly: bool, + ar_guid: str | None, + meta: dict[str, str] | None = None, ): self.connection: databases.Database = connection - self.project: Optional[int] = project + self.project: int | None = project self.author: str = author + self.on_behalf_of: str | None = on_behalf_of + self.readonly: bool = readonly + self.ar_guid: str | None = ar_guid + self.meta = meta + + self._audit_log_id: int | None = None + self._audit_log_lock = asyncio.Lock() + + async def audit_log_id(self): + """Get audit_log ID for write operations, cached per connection""" + if self.readonly: + raise InternalError( + 'Trying to get a audit_log ID, but not a write connection' + ) + + async with self._audit_log_lock: + if not self._audit_log_id: + # pylint: disable=import-outside-toplevel + # make this import here, otherwise we'd have a circular import + from db.python.tables.audit_log import AuditLogTable + + at = AuditLogTable(self) + self._audit_log_id = await at.create_audit_log( + author=self.author, + on_behalf_of=self.on_behalf_of, + ar_guid=self.ar_guid, + comment=None, + project=self.project, + meta=self.meta, + ) + + return self._audit_log_id def assert_requires_project(self): """Assert the project is set, or return an exception""" @@ -132,11 +167,7 @@ def get_connection_string(self): class SMConnections: """Contains useful functions for connecting to the database""" - # _connected = False - # _connections: Dict[str, databases.Database] = {} - # _admin_db: databases.Database = None - - _credentials: Optional[DatabaseConfiguration] = None + _credentials: DatabaseConfiguration | None = None @staticmethod def _get_config(): @@ -175,7 +206,10 @@ async def disconnect(): return False @staticmethod - async def _get_made_connection(): + async def get_made_connection(): + """ + Makes a new connection to the database on each call + """ credentials = SMConnections._get_config() if credentials is None: @@ -190,70 +224,24 @@ async def _get_made_connection(): return conn @staticmethod - async def get_connection(*, author: str, project_name: str, readonly: bool): - """Get a db connection from a project and user""" - # maybe it makes sense to perform permission checks here too - logger.debug(f'Authenticate connection to {project_name} with {author!r}') - - conn = await SMConnections._get_made_connection() - pt = ProjectPermissionsTable(connection=conn) - - project = await pt.get_and_check_access_to_project_for_name( - user=author, project_name=project_name, readonly=readonly - ) - - return Connection(connection=conn, author=author, project=project.id) - - @staticmethod - async def get_connection_no_project(author: str): + async def get_connection_no_project( + author: str, ar_guid: str, meta: dict[str, str] + 
): """Get a db connection from a project and user""" # maybe it makes sense to perform permission checks here too logger.debug(f'Authenticate no-project connection with {author!r}') - conn = await SMConnections._get_made_connection() + conn = await SMConnections.get_made_connection() # we don't authenticate project-less connection, but rely on the # the endpoint to validate the resources - return Connection(connection=conn, author=author, project=None) - - -class DbBase: - """Base class for table subclasses""" - - @classmethod - async def from_project(cls, project, author, readonly: bool): - """Create the Db object from a project with user details""" - return cls( - connection=await SMConnections.get_connection( - project_name=project, author=author, readonly=readonly - ), + return Connection( + connection=conn, + author=author, + project=None, + on_behalf_of=None, + ar_guid=ar_guid, + readonly=False, + meta=meta, ) - - def __init__(self, connection: Connection): - if connection is None: - raise InternalError( - f'No connection was provided to the table {self.__class__.__name__!r}' - ) - if not isinstance(connection, Connection): - raise InternalError( - f'Expected connection type Connection, received {type(connection)}, ' - f'did you mean to call self._connection?' - ) - - self._connection = connection - self.connection: databases.Database = connection.connection - self.author = connection.author - self.project = connection.project - - if self.author is None: - raise InternalError(f'Must provide author to {self.__class__.__name__}') - - # piped from the connection - - @staticmethod - def escape_like_term(query: str): - """ - Escape meaningful keys when using LIKE with a user supplied input - """ - return query.replace('%', '\\%').replace('_', '\\_') diff --git a/db/python/enum_tables/enums.py b/db/python/enum_tables/enums.py index 113daf3d6..1313b0cb4 100644 --- a/db/python/enum_tables/enums.py +++ b/db/python/enum_tables/enums.py @@ -4,7 +4,7 @@ from async_lru import alru_cache -from db.python.connect import DbBase +from db.python.tables.base import DbBase table_name_matcher = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$') @@ -58,12 +58,14 @@ async def insert(self, value: str): Insert a new type """ _query = f""" - INSERT INTO {self._get_table_name()} (id, name) - VALUES (:name, :name) - ON DUPLICATE KEY UPDATE name = :name + INSERT INTO {self._get_table_name()} (id, name, audit_log_id) + VALUES (:name, :name, :audit_log_id) + ON DUPLICATE KEY UPDATE name = :name, audit_log_id = :audit_log_id """ - await self.connection.execute(_query, {'name': value.lower()}) + await self.connection.execute( + _query, {'name': value.lower(), 'audit_log_id': await self.audit_log_id()} + ) # clear the cache so results are up-to-date self.get.cache_clear() # pylint: disable=no-member return value diff --git a/db/python/layers/__init__.py b/db/python/layers/__init__.py index 9c9673649..04dff8dfe 100644 --- a/db/python/layers/__init__.py +++ b/db/python/layers/__init__.py @@ -1,7 +1,9 @@ from db.python.layers.analysis import AnalysisLayer from db.python.layers.assay import AssayLayer +from db.python.layers.audit_log import AuditLogLayer from db.python.layers.base import BaseLayer from db.python.layers.cohort import CohortLayer +from db.python.layers.billing import BillingLayer from db.python.layers.family import FamilyLayer from db.python.layers.participant import ParticipantLayer from db.python.layers.sample import SampleLayer diff --git a/db/python/layers/analysis.py b/db/python/layers/analysis.py index 
e3bc1dd1d..f54c91eaa 100644
--- a/db/python/layers/analysis.py
+++ b/db/python/layers/analysis.py
@@ -1,5 +1,5 @@
-from collections import defaultdict
 import datetime
+from collections import defaultdict
 from typing import Any
 
 from api.utils import group_by
@@ -7,18 +7,19 @@
 from db.python.layers.base import BaseLayer
 from db.python.layers.sequencing_group import SequencingGroupLayer
 from db.python.tables.analysis import AnalysisFilter, AnalysisTable
-from db.python.tables.project import ProjectId
 from db.python.tables.sample import SampleTable
 from db.python.tables.sequencing_group import SequencingGroupFilter
 from db.python.utils import GenericFilter, get_logger
 from models.enums import AnalysisStatus
 from models.models import (
     AnalysisInternal,
+    AuditLogInternal,
     ProportionalDateModel,
     ProportionalDateProjectModel,
     ProportionalDateTemporalMethod,
     SequencingGroupInternal,
 )
+from models.models.project import ProjectId
 from models.models.sequencing_group import SequencingGroupInternalId
 
 ES_ANALYSIS_OBJ_INTRO_DATE = datetime.date(2022, 6, 21)
@@ -530,12 +531,15 @@ async def get_sgs_added_by_day_by_es_indices(
 
         return by_day
 
+    async def get_audit_logs_by_analysis_ids(self, analysis_ids: list[int]) -> dict[int, list[AuditLogInternal]]:
+        """Get audit logs for analysis IDs"""
+        return await self.at.get_audit_log_for_analysis_ids(analysis_ids)
+
     # CREATE / UPDATE
 
     async def create_analysis(
         self,
         analysis: AnalysisInternal,
-        author: str = None,
         project: ProjectId = None,
     ) -> int:
         """Create a new analysis"""
@@ -546,7 +550,6 @@ async def create_analysis(
             meta=analysis.meta,
             output=analysis.output,
             active=analysis.active,
-            author=author,
             project=project,
         )
 
@@ -570,7 +573,6 @@ async def update_analysis(
         status: AnalysisStatus,
         meta: dict[str, Any] = None,
         output: str | None = None,
-        author: str | None = None,
         check_project_id=True,
     ):
         """
@@ -587,7 +589,6 @@
             status=status,
             meta=meta,
             output=output,
-            author=author,
         )
 
     async def get_analysis_runner_log(
diff --git a/db/python/layers/assay.py b/db/python/layers/assay.py
index 32dc9aabb..0e512e6e7 100644
--- a/db/python/layers/assay.py
+++ b/db/python/layers/assay.py
@@ -1,7 +1,7 @@
 # pylint: disable=too-many-arguments
 from typing import Any
 
-from db.python.connect import NoOpAenter
+from db.python.utils import NoOpAenter
 from db.python.layers.base import BaseLayer, Connection
 from db.python.tables.assay import AssayTable, AssayFilter
 from db.python.tables.sample import SampleTable
diff --git a/db/python/layers/audit_log.py b/db/python/layers/audit_log.py
new file mode 100644
index 000000000..1cc351ecc
--- /dev/null
+++ b/db/python/layers/audit_log.py
@@ -0,0 +1,30 @@
+from db.python.layers.base import BaseLayer, Connection
+from db.python.tables.audit_log import AuditLogTable
+from models.models.audit_log import AuditLogId, AuditLogInternal
+
+
+class AuditLogLayer(BaseLayer):
+    """Layer for audit log logic"""
+
+    def __init__(self, connection: Connection):
+        super().__init__(connection)
+        self.alayer: AuditLogTable = AuditLogTable(connection)
+
+    # GET
+    async def get_for_ids(
+        self, ids: list[AuditLogId], check_project_id=True
+    ) -> list[AuditLogInternal]:
+        """Get audit log entries for the given IDs"""
+        if not ids:
+            return []
+
+        logs = await self.alayer.get_audit_logs_for_ids(ids)
+        if check_project_id:
+            projects = {log.auth_project for log in logs}
+            await self.ptable.check_access_to_project_ids(
+                user=self.author, project_ids=projects, readonly=True
+            )
+
+        return logs
+
+    # don't put the create here, I don't want it
to be public diff --git a/db/python/layers/base.py b/db/python/layers/base.py index 3b8f282eb..28c7575a7 100644 --- a/db/python/layers/base.py +++ b/db/python/layers/base.py @@ -1,4 +1,4 @@ -from db.python.connect import Connection, ProjectPermissionsTable +from db.python.tables.project import Connection, ProjectPermissionsTable class BaseLayer: @@ -6,7 +6,7 @@ class BaseLayer: def __init__(self, connection: Connection): self.connection = connection - self.ptable = ProjectPermissionsTable(self.connection.connection) + self.ptable = ProjectPermissionsTable(connection) @property def author(self): diff --git a/db/python/layers/billing.py b/db/python/layers/billing.py index 93ce3cfc7..737be6857 100644 --- a/db/python/layers/billing.py +++ b/db/python/layers/billing.py @@ -1,49 +1,71 @@ -import re - -from typing import Any -from datetime import datetime -from collections import Counter, defaultdict -from google.cloud import bigquery - +from db.python.layers.bq_base import BqBaseLayer +from db.python.tables.bq.billing_ar_batch import BillingArBatchTable +from db.python.tables.bq.billing_daily import BillingDailyTable +from db.python.tables.bq.billing_daily_extended import BillingDailyExtendedTable +from db.python.tables.bq.billing_gcp_daily import BillingGcpDailyTable +from db.python.tables.bq.billing_raw import BillingRawTable +from models.enums import BillingSource, BillingTimeColumn, BillingTimePeriods from models.models import ( - BillingRowRecord, - BillingTotalCostRecord, - BillingTotalCostQueryModel, BillingColumn, BillingCostBudgetRecord, + BillingHailBatchCostRecord, + BillingTotalCostQueryModel, ) -from db.python.gcp_connect import BqDbBase -from db.python.layers.bq_base import BqBaseLayer -from db.python.tables.billing import BillingFilter - -from api.settings import ( - BQ_DAYS_BACK_OPTIMAL, - BQ_AGGREG_VIEW, - BQ_AGGREG_RAW, - BQ_AGGREG_EXT_VIEW, - BQ_BUDGET_VIEW, - BQ_GCP_BILLING_VIEW, -) -from api.utils.dates import get_invoice_month_range, reformat_datetime - - -def abbrev_cost_category(cost_category: str) -> str: - """abbreviate cost category""" - return 'S' if cost_category == 'Cloud Storage' else 'C' - class BillingLayer(BqBaseLayer): """Billing layer""" + def table_factory( + self, + source: BillingSource, + fields: list[BillingColumn] | None = None, + filters: dict[BillingColumn, str | list | dict] | None = None, + ) -> ( + BillingDailyTable + | BillingDailyExtendedTable + | BillingGcpDailyTable + | BillingRawTable + ): + """Get billing table object based on source and fields""" + if source == BillingSource.GCP_BILLING: + return BillingGcpDailyTable(self.connection) + if source == BillingSource.RAW: + return BillingRawTable(self.connection) + + # check if any of the fields is in the extended columns + if fields: + used_extended_cols = [ + f + for f in fields + if f in BillingColumn.extended_cols() and BillingColumn.can_group_by(f) + ] + if used_extended_cols: + # there is a field from extended daily table + return BillingDailyExtendedTable(self.connection) + + # check if any of the filters is in the extended columns + if filters: + used_extended_cols = [ + f + for f in filters + if f in BillingColumn.extended_cols() and BillingColumn.can_group_by(f) + ] + if used_extended_cols: + # there is a field from extended daily table + return BillingDailyExtendedTable(self.connection) + + # by default look at the daily table + return BillingDailyTable(self.connection) + async def get_gcp_projects( self, ) -> list[str] | None: """ Get All GCP projects in database """ - billing_db = 
BillingDb(self.connection) - return await billing_db.get_gcp_projects() + billing_table = BillingGcpDailyTable(self.connection) + return await billing_table.get_gcp_projects() async def get_topics( self, @@ -51,8 +73,8 @@ async def get_topics( """ Get All topics in database """ - billing_db = BillingDb(self.connection) - return await billing_db.get_topics() + billing_table = BillingDailyTable(self.connection) + return await billing_table.get_topics() async def get_cost_categories( self, @@ -60,8 +82,8 @@ async def get_cost_categories( """ Get All service description / cost categories in database """ - billing_db = BillingDb(self.connection) - return await billing_db.get_cost_categories() + billing_table = BillingDailyTable(self.connection) + return await billing_table.get_cost_categories() async def get_skus( self, @@ -71,8 +93,8 @@ async def get_skus( """ Get All SKUs in database """ - billing_db = BillingDb(self.connection) - return await billing_db.get_skus(limit, offset) + billing_table = BillingDailyTable(self.connection) + return await billing_table.get_skus(limit, offset) async def get_datasets( self, @@ -80,8 +102,8 @@ async def get_datasets( """ Get All datasets in database """ - billing_db = BillingDb(self.connection) - return await billing_db.get_extended_values('dataset') + billing_table = BillingDailyExtendedTable(self.connection) + return await billing_table.get_extended_values('dataset') async def get_stages( self, @@ -89,8 +111,8 @@ async def get_stages( """ Get All stages in database """ - billing_db = BillingDb(self.connection) - return await billing_db.get_extended_values('stage') + billing_table = BillingDailyExtendedTable(self.connection) + return await billing_table.get_extended_values('stage') async def get_sequencing_types( self, @@ -98,8 +120,8 @@ async def get_sequencing_types( """ Get All sequencing_types in database """ - billing_db = BillingDb(self.connection) - return await billing_db.get_extended_values('sequencing_type') + billing_table = BillingDailyExtendedTable(self.connection) + return await billing_table.get_extended_values('sequencing_type') async def get_sequencing_groups( self, @@ -107,865 +129,172 @@ async def get_sequencing_groups( """ Get All sequencing_groups in database """ - billing_db = BillingDb(self.connection) - return await billing_db.get_extended_values('sequencing_group') + billing_table = BillingDailyExtendedTable(self.connection) + return await billing_table.get_extended_values('sequencing_group') - async def get_invoice_months( + async def get_compute_categories( self, ) -> list[str] | None: """ - Get All invoice months in database + Get All compute_category values in database """ - billing_db = BillingDb(self.connection) - return await billing_db.get_invoice_months() + billing_table = BillingDailyExtendedTable(self.connection) + return await billing_table.get_extended_values('compute_category') - async def query( + async def get_cromwell_sub_workflow_names( self, - _filter: BillingFilter, - limit: int = 10, - ) -> list[BillingRowRecord] | None: + ) -> list[str] | None: """ - Get Billing record for the given gilter + Get All cromwell_sub_workflow_name values in database """ - billing_db = BillingDb(self.connection) - return await billing_db.query(_filter, limit) + billing_table = BillingDailyExtendedTable(self.connection) + return await billing_table.get_extended_values('cromwell_sub_workflow_name') - async def get_total_cost( + async def get_wdl_task_names( self, - query: BillingTotalCostQueryModel, - ) -> 
list[BillingTotalCostRecord] | None: + ) -> list[str] | None: """ - Get Total cost of selected fields for requested time interval + Get All wdl_task_name values in database """ - billing_db = BillingDb(self.connection) - return await billing_db.get_total_cost(query) + billing_table = BillingDailyExtendedTable(self.connection) + return await billing_table.get_extended_values('wdl_task_name') - async def get_running_cost( + async def get_invoice_months( self, - field: BillingColumn, - invoice_month: str | None = None, - source: str | None = None, - ) -> list[BillingCostBudgetRecord]: - """ - Get Running costs including monthly budget - """ - billing_db = BillingDb(self.connection) - return await billing_db.get_running_cost(field, invoice_month, source) - - -class BillingDb(BqDbBase): - """Db layer for billing related routes""" - - async def get_gcp_projects(self): - """Get all GCP projects in database""" - - # cost of this BQ is 10MB on DEV is minimal, AU$ 0.000008 per query - # @days is defined by env variable BQ_DAYS_BACK_OPTIMAL - # this part_time > filter is to limit the amount of data scanned, - # saving cost for running BQ - _query = f""" - SELECT DISTINCT gcp_project - FROM `{BQ_GCP_BILLING_VIEW}` - WHERE part_time > TIMESTAMP_ADD( - CURRENT_TIMESTAMP(), INTERVAL @days DAY - ) - AND gcp_project IS NOT NULL - ORDER BY gcp_project ASC; - """ - - job_config = bigquery.QueryJobConfig( - query_parameters=[ - bigquery.ScalarQueryParameter( - 'days', 'INT64', -int(BQ_DAYS_BACK_OPTIMAL) - ), - ] - ) - - query_job_result = list( - self._connection.connection.query(_query, job_config=job_config).result() - ) - if query_job_result: - return [str(dict(row)['gcp_project']) for row in query_job_result] - - # return empty list if no record found - return [] - - async def get_topics(self): - """Get all topics in database""" - - # cost of this BQ is 10MB on DEV is minimal, AU$ 0.000008 per query - # @days is defined by env variable BQ_DAYS_BACK_OPTIMAL - # this day > filter is to limit the amount of data scanned, - # saving cost for running BQ - # aggregated views are partitioned by day - _query = f""" - SELECT DISTINCT topic - FROM `{BQ_AGGREG_VIEW}` - WHERE day > TIMESTAMP_ADD( - CURRENT_TIMESTAMP(), INTERVAL @days DAY - ) - ORDER BY topic ASC; + ) -> list[str] | None: """ - - job_config = bigquery.QueryJobConfig( - query_parameters=[ - bigquery.ScalarQueryParameter( - 'days', 'INT64', -int(BQ_DAYS_BACK_OPTIMAL) - ), - ] - ) - - query_job_result = list( - self._connection.connection.query(_query, job_config=job_config).result() - ) - if query_job_result: - return [str(dict(row)['topic']) for row in query_job_result] - - # return empty list if no record found - return [] - - async def get_invoice_months(self): - """Get all invoice months in database""" - - _query = f""" - SELECT DISTINCT FORMAT_DATE("%Y%m", day) as invoice_month - FROM `{BQ_AGGREG_VIEW}` - WHERE EXTRACT(day from day) = 1 - ORDER BY invoice_month DESC; - """ - - query_job_result = list(self._connection.connection.query(_query).result()) - if query_job_result: - return [str(dict(row)['invoice_month']) for row in query_job_result] - - # return empty list if no record found - return [] - - async def get_cost_categories(self): - """Get all service description in database""" - - # cost of this BQ is 10MB on DEV is minimal, AU$ 0.000008 per query - # @days is defined by env variable BQ_DAYS_BACK_OPTIMAL - # this day > filter is to limit the amount of data scanned, - # saving cost for running BQ - # aggregated views are partitioned by day - 
_query = f""" - SELECT DISTINCT cost_category - FROM `{BQ_AGGREG_VIEW}` - WHERE day > TIMESTAMP_ADD( - CURRENT_TIMESTAMP(), INTERVAL @days DAY - ) - ORDER BY cost_category ASC; + Get All invoice months in database """ + billing_table = BillingDailyTable(self.connection) + return await billing_table.get_invoice_months() - job_config = bigquery.QueryJobConfig( - query_parameters=[ - bigquery.ScalarQueryParameter( - 'days', 'INT64', -int(BQ_DAYS_BACK_OPTIMAL) - ), - ] - ) - - query_job_result = list( - self._connection.connection.query(_query, job_config=job_config).result() - ) - if query_job_result: - return [str(dict(row)['cost_category']) for row in query_job_result] - - # return empty list if no record found - return [] - - async def get_skus( + async def get_namespaces( self, - limit: int | None = None, - offset: int | None = None, - ): - """Get all SKUs in database""" - - # cost of this BQ is 10MB on DEV is minimal, AU$ 0.000008 per query - # @days is defined by env variable BQ_DAYS_BACK_OPTIMAL - # this day > filter is to limit the amount of data scanned, - # saving cost for running BQ - # aggregated views are partitioned by day - _query = f""" - SELECT DISTINCT sku - FROM `{BQ_AGGREG_VIEW}` - WHERE day > TIMESTAMP_ADD( - CURRENT_TIMESTAMP(), INTERVAL @days DAY - ) - ORDER BY sku ASC - """ - - # append LIMIT and OFFSET if present - if limit: - _query += ' LIMIT @limit_val' - if offset: - _query += ' OFFSET @offset_val' - - job_config = bigquery.QueryJobConfig( - query_parameters=[ - bigquery.ScalarQueryParameter( - 'days', 'INT64', -int(BQ_DAYS_BACK_OPTIMAL) - ), - bigquery.ScalarQueryParameter('limit_val', 'INT64', limit), - bigquery.ScalarQueryParameter('offset_val', 'INT64', offset), - ] - ) - - query_job_result = list( - self._connection.connection.query(_query, job_config=job_config).result() - ) - if query_job_result: - return [str(dict(row)['sku']) for row in query_job_result] - - # return empty list if no record found - return [] - - async def get_extended_values(self, field: str): + ) -> list[str] | None: """ - Get all extended values in database, - e.g. 
dataset, stage, sequencing_type or sequencing_group + Get All namespaces values in database """ - - # cost of this BQ is 10MB on DEV is minimal, AU$ 0.000008 per query - # @days is defined by env variable BQ_DAYS_BACK_OPTIMAL - # this day > filter is to limit the amount of data scanned, - # saving cost for running BQ - # aggregated views are partitioned by day - _query = f""" - SELECT DISTINCT {field} - FROM `{BQ_AGGREG_EXT_VIEW}` - WHERE {field} IS NOT NULL - AND day > TIMESTAMP_ADD( - CURRENT_TIMESTAMP(), INTERVAL @days DAY - ) - ORDER BY 1 ASC; - """ - - job_config = bigquery.QueryJobConfig( - query_parameters=[ - bigquery.ScalarQueryParameter( - 'days', 'INT64', -int(BQ_DAYS_BACK_OPTIMAL) - ), - ] - ) - - query_job_result = list( - self._connection.connection.query(_query, job_config=job_config).result() - ) - if query_job_result: - return [str(dict(row)[field]) for row in query_job_result] - - # return empty list if no record found - return [] - - async def query( - self, - filter_: BillingFilter, - limit: int = 10, - ) -> list[BillingRowRecord] | None: - """Get Billing record from BQ""" - - # TODO: THis function is not going to be used most likely - # get_total_cost will replace it - - # cost of this BQ is 30MB on DEV, - # DEV is partition by day and date is required filter params, - # cost is aprox per query: AU$ 0.000023 per query - - required_fields = [ - filter_.date, - ] - - if not any(required_fields): - raise ValueError('Must provide date to filter on') - - # construct filters - filters = [] - query_parameters = [] - - if filter_.topic: - filters.append('topic IN UNNEST(@topic)') - query_parameters.append( - bigquery.ArrayQueryParameter('topic', 'STRING', filter_.topic.in_), - ) - - if filter_.date: - filters.append('DATE_TRUNC(usage_end_time, DAY) = TIMESTAMP(@date)') - query_parameters.append( - bigquery.ScalarQueryParameter('date', 'STRING', filter_.date.eq), - ) - - if filter_.cost_category: - filters.append('service.description IN UNNEST(@cost_category)') - query_parameters.append( - bigquery.ArrayQueryParameter( - 'cost_category', 'STRING', filter_.cost_category.in_ - ), - ) - - filter_str = 'WHERE ' + ' AND '.join(filters) if filters else '' - - _query = f""" - SELECT id, topic, service, sku, usage_start_time, usage_end_time, project, - labels, export_time, cost, currency, currency_conversion_rate, invoice, cost_type - FROM `{BQ_AGGREG_RAW}` - {filter_str} - """ - if limit: - _query += ' LIMIT @limit_val' - query_parameters.append( - bigquery.ScalarQueryParameter('limit_val', 'INT64', limit) - ) - - job_config = bigquery.QueryJobConfig(query_parameters=query_parameters) - query_job_result = list( - self._connection.connection.query(_query, job_config=job_config).result() - ) - - if query_job_result: - return [BillingRowRecord.from_json(dict(row)) for row in query_job_result] - - raise ValueError('No record found') + billing_table = BillingDailyExtendedTable(self.connection) + return await billing_table.get_extended_values('namespace') async def get_total_cost( self, query: BillingTotalCostQueryModel, - ) -> list[BillingTotalCostRecord] | None: - """ - Get Total cost of selected fields for requested time interval from BQ view + ) -> list[dict] | None: """ - if not query.start_date or not query.end_date or not query.fields: - raise ValueError('Date and Fields are required') - - extended_cols = BillingColumn.extended_cols() - - # by default look at the normal view - if query.source == 'gcp_billing': - view_to_use = BQ_GCP_BILLING_VIEW - else: - view_to_use = BQ_AGGREG_VIEW 
- - columns = [] - for field in query.fields: - col_name = str(field.value) - if col_name == 'cost': - # skip the cost field as it will be always present - continue - - if col_name in extended_cols: - # if one of the extended columns is needed, the view has to be extended - view_to_use = BQ_AGGREG_EXT_VIEW - - columns.append(col_name) - - fields_selected = ','.join(columns) - - # construct filters - filters = [] - query_parameters = [] - - filters.append('day >= TIMESTAMP(@start_date)') - query_parameters.append( - bigquery.ScalarQueryParameter('start_date', 'STRING', query.start_date) - ) - - filters.append('day <= TIMESTAMP(@end_date)') - query_parameters.append( - bigquery.ScalarQueryParameter('end_date', 'STRING', query.end_date) - ) - - if query.source == 'gcp_billing': - # BQ_GCP_BILLING_VIEW view is partitioned by different field - # BQ has limitation, materialized view can only by partition by base table - # partition or its subset, in our case _PARTITIONTIME - # (part_time field in the view) - # We are querying by day, - # which can be up to a week behind regarding _PARTITIONTIME - filters.append('part_time >= TIMESTAMP(@start_date)') - filters.append( - 'part_time <= TIMESTAMP_ADD(TIMESTAMP(@end_date), INTERVAL 7 DAY)' - ) - - if query.filters: - for filter_key, filter_value in query.filters.items(): - col_name = str(filter_key.value) - filters.append(f'{col_name} = @{col_name}') - query_parameters.append( - bigquery.ScalarQueryParameter(col_name, 'STRING', filter_value) - ) - if col_name in extended_cols: - # if one of the extended columns is needed, - # the view has to be extended - view_to_use = BQ_AGGREG_EXT_VIEW - - filter_str = 'WHERE ' + ' AND '.join(filters) if filters else '' - - # construct order by - order_by_cols = [] - if query.order_by: - for order_field, reverse in query.order_by.items(): - col_name = str(order_field.value) - col_order = 'DESC' if reverse else 'ASC' - order_by_cols.append(f'{col_name} {col_order}') - - order_by_str = f'ORDER BY {",".join(order_by_cols)}' if order_by_cols else '' - - _query = f""" - SELECT {fields_selected}, SUM(cost) as cost - FROM `{view_to_use}` - {filter_str} - GROUP BY {fields_selected} - {order_by_str} - """ - - # append LIMIT and OFFSET if present - if query.limit: - _query += ' LIMIT @limit_val' - query_parameters.append( - bigquery.ScalarQueryParameter('limit_val', 'INT64', query.limit) - ) - if query.offset: - _query += ' OFFSET @offset_val' - query_parameters.append( - bigquery.ScalarQueryParameter('offset_val', 'INT64', query.offset) - ) - - job_config = bigquery.QueryJobConfig(query_parameters=query_parameters) - query_job_result = list( - self._connection.connection.query(_query, job_config=job_config).result() - ) - - if query_job_result: - return [ - BillingTotalCostRecord.from_json(dict(row)) for row in query_job_result - ] - - # return empty list if no record found - return [] - - async def get_budgets_by_gcp_project( - self, field: BillingColumn, is_current_month: bool - ) -> dict[str, float]: - """ - Get budget for gcp-projects - """ - if field != BillingColumn.PROJECT or not is_current_month: - # only projects have budget and only for current month - return {} - - _query = f""" - WITH t AS ( - SELECT gcp_project, MAX(created_at) as last_created_at - FROM `{BQ_BUDGET_VIEW}` - GROUP BY 1 - ) - SELECT t.gcp_project, d.budget - FROM t inner join `{BQ_BUDGET_VIEW}` d - ON d.gcp_project = t.gcp_project AND d.created_at = t.last_created_at - """ - - query_job_result = 
list(self._connection.connection.query(_query).result()) - - if query_job_result: - return {row.gcp_project: row.budget for row in query_job_result} - - return {} - - async def get_last_loaded_day(self): - """Get the most recent fully loaded day in db - Go 2 days back as the data is not always available for the current day - 1 day back is not enough - """ - - _query = f""" - SELECT TIMESTAMP_ADD(MAX(day), INTERVAL -2 DAY) as last_loaded_day - FROM `{BQ_AGGREG_VIEW}` - WHERE day > TIMESTAMP_ADD( - CURRENT_TIMESTAMP(), INTERVAL @days DAY - ) + Get Total cost of selected fields for requested time interval """ + billing_table = self.table_factory(query.source, query.fields, query.filters) + return await billing_table.get_total_cost(query) - job_config = bigquery.QueryJobConfig( - query_parameters=[ - bigquery.ScalarQueryParameter( - 'days', 'INT64', -int(BQ_DAYS_BACK_OPTIMAL) - ), - ] - ) - - query_job_result = list( - self._connection.connection.query(_query, job_config=job_config).result() - ) - if query_job_result: - return str(query_job_result[0].last_loaded_day) - - return None - - async def prepare_daily_cost_subquery( - self, field, view_to_use, source, query_params - ): - """prepare daily cost subquery""" - - if source == 'gcp_billing': - # add extra filter to limit materialized view partition - # Raw BQ billing table is partitioned by part_time (when data are loaded) - # and not by end of usage time (day) - # There is a delay up to 4-5 days between part_time and day - # 7 days is added to be sure to get all data - gcp_billing_optimise_filter = """ - AND part_time >= TIMESTAMP(@last_loaded_day) - AND part_time <= TIMESTAMP_ADD( - TIMESTAMP(@last_loaded_day), INTERVAL 7 DAY - ) - """ - else: - gcp_billing_optimise_filter = '' - - # Find the last fully loaded day in the view - last_loaded_day = await self.get_last_loaded_day() - - daily_cost_field = ', day.cost as daily_cost' - daily_cost_join = f"""LEFT JOIN ( - SELECT - {field.value} as field, - cost_category, - SUM(cost) as cost - FROM - `{view_to_use}` - WHERE day = TIMESTAMP(@last_loaded_day) - {gcp_billing_optimise_filter} - GROUP BY - field, - cost_category - ) day - ON month.field = day.field - AND month.cost_category = day.cost_category - """ - - query_params.append( - bigquery.ScalarQueryParameter('last_loaded_day', 'STRING', last_loaded_day), - ) - return (last_loaded_day, query_params, daily_cost_field, daily_cost_join) - - async def execute_running_cost_query( + async def get_running_cost( self, field: BillingColumn, invoice_month: str | None = None, - source: str | None = None, - ): + source: BillingSource | None = None, + ) -> list[BillingCostBudgetRecord]: """ - Run query to get running cost of selected field - """ - # check if invoice month is valid first - if not invoice_month or not re.match(r'^\d{6}$', invoice_month): - raise ValueError('Invalid invoice month') - - invoice_month_date = datetime.strptime(invoice_month, '%Y%m') - if invoice_month != invoice_month_date.strftime('%Y%m'): - raise ValueError('Invalid invoice month') - - # get start day and current day for given invoice month - # This is to optimise the query, BQ view is partitioned by day - # and not by invoice month - start_day_date, last_day_date = get_invoice_month_range(invoice_month_date) - start_day = start_day_date.strftime('%Y-%m-%d') - last_day = last_day_date.strftime('%Y-%m-%d') - - # by default look at the normal view - if field in BillingColumn.extended_cols(): - # if any of the extendeid fields are needed use the extended view - view_to_use = 
BQ_AGGREG_EXT_VIEW - elif source == 'gcp_billing': - # if source is gcp_billing, - # use the view on top of the raw billing table - view_to_use = BQ_GCP_BILLING_VIEW - else: - # otherwise use the normal view - view_to_use = BQ_AGGREG_VIEW - - if source == 'gcp_billing': - # add extra filter to limit materialized view partition - # Raw BQ billing table is partitioned by part_time (when data are loaded) - # and not by end of usage time (day) - # There is a delay up to 4-5 days between part_time and day - # 7 days is added to be sure to get all data - filter_to_optimise_query = """ - part_time >= TIMESTAMP(@start_day) - AND part_time <= TIMESTAMP_ADD( - TIMESTAMP(@last_day), INTERVAL 7 DAY - ) - """ - else: - # add extra filter to limit materialized view partition - filter_to_optimise_query = """ - day >= TIMESTAMP(@start_day) - AND day <= TIMESTAMP(@last_day) - """ - - # start_day and last_day are in to optimise the query - query_params = [ - bigquery.ScalarQueryParameter('start_day', 'STRING', start_day), - bigquery.ScalarQueryParameter('last_day', 'STRING', last_day), - ] - - current_day = datetime.now().strftime('%Y-%m-%d') - is_current_month = last_day >= current_day - last_loaded_day = None - - if is_current_month: - # Only current month can have last 24 hours cost - # Last 24H in UTC time - ( - last_loaded_day, - query_params, - daily_cost_field, - daily_cost_join, - ) = await self.prepare_daily_cost_subquery( - field, view_to_use, source, query_params - ) - else: - # Do not calculate last 24H cost - daily_cost_field = ', NULL as daily_cost' - daily_cost_join = '' - - _query = f""" - SELECT - CASE WHEN month.field IS NULL THEN 'N/A' ELSE month.field END as field, - month.cost_category, - month.cost as monthly_cost - {daily_cost_field} - FROM - ( - SELECT - {field.value} as field, - cost_category, - SUM(cost) as cost - FROM - `{view_to_use}` - WHERE {filter_to_optimise_query} - AND invoice_month = @invoice_month - GROUP BY - field, - cost_category - HAVING cost > 0.1 - ) month - {daily_cost_join} - ORDER BY field ASC, daily_cost DESC, monthly_cost DESC; - """ - - query_params.append( - bigquery.ScalarQueryParameter('invoice_month', 'STRING', invoice_month) - ) - - return ( - is_current_month, - last_loaded_day, - list( - self._connection.connection.query( - _query, - job_config=bigquery.QueryJobConfig(query_parameters=query_params), - ).result() - ), - ) + Get Running costs including monthly budget + """ + billing_table = self.table_factory(source, [field]) + return await billing_table.get_running_cost(field, invoice_month) - async def append_total_running_cost( + async def get_cost_by_ar_guid( self, - field: BillingColumn, - is_current_month: bool, - last_loaded_day: str | None, - total_monthly: dict, - total_daily: dict, - total_monthly_category: dict, - total_daily_category: dict, - results: list[BillingCostBudgetRecord], - ) -> list[BillingCostBudgetRecord]: + ar_guid: str | None = None, + ) -> BillingHailBatchCostRecord: """ - Add total row: compute + storage to the results - """ - # construct ALL fields details - all_details = [] - for cat, mth_cost in total_monthly_category.items(): - all_details.append( - { - 'cost_group': abbrev_cost_category(cat), - 'cost_category': cat, - 'daily_cost': total_daily_category[cat] - if is_current_month - else None, - 'monthly_cost': mth_cost, - } + Get Costs by AR GUID + """ + ar_batch_lookup_table = BillingArBatchTable(self.connection) + + # First get all batches and the min/max day to use for the query + ( + start_day, + end_day, + batches, + 
) = await ar_batch_lookup_table.get_batches_by_ar_guid(ar_guid) + + if not batches: + return BillingHailBatchCostRecord( + ar_guid=ar_guid, + batch_ids=[], + costs=[], ) - # add total row: compute + storage - results.append( - BillingCostBudgetRecord.from_json( - { - 'field': f'{BillingColumn.generate_all_title(field)}', - 'total_monthly': ( - total_monthly['C']['ALL'] + total_monthly['S']['ALL'] - ), - 'total_daily': (total_daily['C']['ALL'] + total_daily['S']['ALL']) - if is_current_month - else None, - 'compute_monthly': total_monthly['C']['ALL'], - 'compute_daily': (total_daily['C']['ALL']) - if is_current_month - else None, - 'storage_monthly': total_monthly['S']['ALL'], - 'storage_daily': (total_daily['S']['ALL']) - if is_current_month - else None, - 'details': all_details, - 'last_loaded_day': last_loaded_day, + # Then get the costs for the given AR GUID/batches from the main table + all_cols = [BillingColumn.str_to_enum(v) for v in BillingColumn.raw_cols()] + + query = BillingTotalCostQueryModel( + fields=all_cols, + source=BillingSource.RAW, + start_date=start_day.strftime('%Y-%m-%d'), + end_date=end_day.strftime('%Y-%m-%d'), + filters={ + BillingColumn.LABELS: { + 'batch_id': batches, + 'ar-guid': ar_guid, } - ) + }, + filters_op='OR', + group_by=False, + time_column=BillingTimeColumn.USAGE_END_TIME, + time_periods=BillingTimePeriods.DAY, ) - return results - - async def append_running_cost_records( - self, - field: BillingColumn, - is_current_month: bool, - last_loaded_day: str | None, - total_monthly: dict, - total_daily: dict, - field_details: dict, - results: list[BillingCostBudgetRecord], - ) -> list[BillingCostBudgetRecord]: - """ - Add all the selected field rows: compute + storage to the results - """ - # get budget map per gcp project - budgets_per_gcp_project = await self.get_budgets_by_gcp_project( - field, is_current_month + billing_table = self.table_factory(query.source, query.fields) + records = await billing_table.get_total_cost(query) + return BillingHailBatchCostRecord( + ar_guid=ar_guid, + batch_ids=batches, + costs=records, ) - # add rows by field - for key, details in field_details.items(): - compute_daily = total_daily['C'][key] if key in total_daily['C'] else 0 - storage_daily = total_daily['S'][key] if key in total_daily['S'] else 0 - compute_monthly = ( - total_monthly['C'][key] if key in total_monthly['C'] else 0 - ) - storage_monthly = ( - total_monthly['S'][key] if key in total_monthly['S'] else 0 - ) - monthly = compute_monthly + storage_monthly - budget_monthly = budgets_per_gcp_project.get(key) - - results.append( - BillingCostBudgetRecord.from_json( - { - 'field': key, - 'total_monthly': monthly, - 'total_daily': (compute_daily + storage_daily) - if is_current_month - else None, - 'compute_monthly': compute_monthly, - 'compute_daily': compute_daily, - 'storage_monthly': storage_monthly, - 'storage_daily': storage_daily, - 'details': details, - 'budget_spent': 100 * monthly / budget_monthly - if budget_monthly - else None, - 'last_loaded_day': last_loaded_day, - } - ) - ) - - return results - - async def get_running_cost( + async def get_cost_by_batch_id( self, - field: BillingColumn, - invoice_month: str | None = None, - source: str | None = None, - ) -> list[BillingCostBudgetRecord]: + batch_id: str | None = None, + ) -> BillingHailBatchCostRecord: """ - Get currently running cost of selected field + Get Costs by Batch ID """ + ar_batch_lookup_table = BillingArBatchTable(self.connection) - # accept only Topic, Dataset or Project at this stage - if 
field not in ( - BillingColumn.TOPIC, - BillingColumn.PROJECT, - BillingColumn.DATASET, - ): - raise ValueError('Invalid field only topic, dataset or project allowed') + # First get all batches and the min/max day to use for the query + ar_guid = await ar_batch_lookup_table.get_ar_guid_by_batch_id(batch_id) + # The get all batches for the ar_guid ( - is_current_month, - last_loaded_day, - query_job_result, - ) = await self.execute_running_cost_query(field, invoice_month, source) - if not query_job_result: - # return empty list - return [] - - # prepare data - results: list[BillingCostBudgetRecord] = [] - - # reformat last_loaded_day if present - last_loaded_day = reformat_datetime( - last_loaded_day, '%Y-%m-%d %H:%M:%S+00:00', '%b %d' - ) - - total_monthly: dict[str, Counter[str]] = defaultdict(Counter) - total_daily: dict[str, Counter[str]] = defaultdict(Counter) - field_details: dict[str, list[Any]] = defaultdict(list) - total_monthly_category: Counter[str] = Counter() - total_daily_category: Counter[str] = Counter() - - for row in query_job_result: - if row.field not in field_details: - field_details[row.field] = [] - - cost_group = abbrev_cost_category(row.cost_category) - - field_details[row.field].append( - { - 'cost_group': cost_group, - 'cost_category': row.cost_category, - 'daily_cost': row.daily_cost if is_current_month else None, - 'monthly_cost': row.monthly_cost, + start_day, + end_day, + batches, + ) = await ar_batch_lookup_table.get_batches_by_ar_guid(ar_guid) + + if not batches: + return BillingHailBatchCostRecord(ar_guid=ar_guid, batch_ids=[], costs=[]) + + # Then get the costs for the given AR GUID/batches from the main table + all_cols = [BillingColumn.str_to_enum(v) for v in BillingColumn.raw_cols()] + + query = BillingTotalCostQueryModel( + fields=all_cols, + source=BillingSource.RAW, + start_date=start_day.strftime('%Y-%m-%d'), + end_date=end_day.strftime('%Y-%m-%d'), + filters={ + BillingColumn.LABELS: { + 'batch_id': batches, + 'ar-guid': ar_guid, } - ) - - total_monthly_category[row.cost_category] += row.monthly_cost - if row.daily_cost: - total_daily_category[row.cost_category] += row.daily_cost - - # cost groups totals - total_monthly[cost_group]['ALL'] += row.monthly_cost - total_monthly[cost_group][row.field] += row.monthly_cost - if row.daily_cost and is_current_month: - total_daily[cost_group]['ALL'] += row.daily_cost - total_daily[cost_group][row.field] += row.daily_cost - - # add total row: compute + storage - results = await self.append_total_running_cost( - field, - is_current_month, - last_loaded_day, - total_monthly, - total_daily, - total_monthly_category, - total_daily_category, - results, + }, + filters_op='OR', + group_by=False, + time_column=BillingTimeColumn.USAGE_END_TIME, + time_periods=BillingTimePeriods.DAY, ) - - # add rest of the records: compute + storage - results = await self.append_running_cost_records( - field, - is_current_month, - last_loaded_day, - total_monthly, - total_daily, - field_details, - results, + billing_table = self.table_factory(query.source, query.fields) + records = await billing_table.get_total_cost(query) + return BillingHailBatchCostRecord( + ar_guid=ar_guid, + batch_ids=batches, + costs=records, ) - - return results diff --git a/db/python/layers/family.py b/db/python/layers/family.py index d90aa2612..3d77576b0 100644 --- a/db/python/layers/family.py +++ b/db/python/layers/family.py @@ -1,6 +1,6 @@ # pylint: disable=used-before-assignment import logging -from typing import List, Union, Optional, Dict +from typing 
import Dict, List, Optional, Union from db.python.connect import Connection from db.python.layers.base import BaseLayer @@ -8,11 +8,11 @@ from db.python.tables.family import FamilyTable from db.python.tables.family_participant import FamilyParticipantTable from db.python.tables.participant import ParticipantTable -from db.python.tables.project import ProjectId -from db.python.tables.sample import SampleTable, SampleFilter +from db.python.tables.sample import SampleFilter, SampleTable from db.python.utils import GenericFilter -from models.models.family import PedRowInternal, FamilyInternal +from models.models.family import FamilyInternal, PedRowInternal from models.models.participant import ParticipantUpsertInternal +from models.models.project import ProjectId class PedRow: diff --git a/db/python/layers/participant.py b/db/python/layers/participant.py index 44d6d4db2..3698e0cc8 100644 --- a/db/python/layers/participant.py +++ b/db/python/layers/participant.py @@ -4,7 +4,6 @@ from enum import Enum from typing import Any, Dict, List, Optional, Tuple -from db.python.connect import NoOpAenter, NotFoundError from db.python.layers.base import BaseLayer from db.python.layers.sample import SampleLayer from db.python.tables.family import FamilyTable @@ -12,9 +11,10 @@ from db.python.tables.participant import ParticipantTable from db.python.tables.participant_phenotype import ParticipantPhenotypeTable from db.python.tables.sample import SampleTable -from db.python.utils import ProjectId, split_generic_terms +from db.python.utils import NoOpAenter, NotFoundError, split_generic_terms from models.models.family import PedRowInternal from models.models.participant import ParticipantInternal, ParticipantUpsertInternal +from models.models.project import ProjectId HPO_REGEX_MATCHER = re.compile(r'HP\:\d+$') @@ -571,7 +571,6 @@ async def get_external_participant_id_to_internal_sequencing_group_id_map( async def upsert_participant( self, participant: ParticipantUpsertInternal, - author: str = None, project: ProjectId = None, check_project_id: bool = True, open_transaction=True, @@ -602,7 +601,6 @@ async def upsert_participant( reported_gender=participant.reported_gender, meta=participant.meta, karyotype=participant.karyotype, - author=author, ) else: @@ -612,7 +610,6 @@ async def upsert_participant( reported_gender=participant.reported_gender, karyotype=participant.karyotype, meta=participant.meta, - author=author, project=project, ) @@ -623,7 +620,6 @@ async def upsert_participant( await slayer.upsert_samples( participant.samples, - author=author, project=project, check_project_id=False, open_transaction=False, @@ -802,7 +798,6 @@ async def add_participant_to_family( maternal_id=maternal_id, affected=affected, notes=None, - author=None, ) @staticmethod diff --git a/db/python/layers/sample.py b/db/python/layers/sample.py index 4b4d2f4f1..dbb3dbcf0 100644 --- a/db/python/layers/sample.py +++ b/db/python/layers/sample.py @@ -2,18 +2,16 @@ from typing import Any from api.utils import group_by -from db.python.connect import NotFoundError from db.python.layers.assay import AssayLayer from db.python.layers.base import BaseLayer, Connection from db.python.layers.sequencing_group import SequencingGroupLayer from db.python.tables.assay import NoOpAenter -from db.python.tables.project import ProjectId, ProjectPermissionsTable -from db.python.tables.sample import SampleTable, SampleFilter -from db.python.utils import GenericFilter +from db.python.tables.project import ProjectPermissionsTable +from db.python.tables.sample 
import SampleFilter, SampleTable +from db.python.utils import GenericFilter, NotFoundError +from models.models.project import ProjectId from models.models.sample import SampleInternal, SampleUpsertInternal -from models.utils.sample_id_format import ( - sample_id_format_list, -) +from models.utils.sample_id_format import sample_id_format_list class SampleLayer(BaseLayer): @@ -22,7 +20,7 @@ class SampleLayer(BaseLayer): def __init__(self, connection: Connection): super().__init__(connection) self.st: SampleTable = SampleTable(connection) - self.pt = ProjectPermissionsTable(connection.connection) + self.pt = ProjectPermissionsTable(connection) self.connection = connection # GETS @@ -220,7 +218,6 @@ async def get_samples_create_date( async def upsert_sample( self, sample: SampleUpsertInternal, - author: str = None, project: ProjectId = None, process_sequencing_groups: bool = True, process_assays: bool = True, @@ -239,7 +236,6 @@ async def upsert_sample( active=True, meta=sample.meta, participant_id=sample.participant_id, - author=author, project=project, ) else: @@ -276,7 +272,6 @@ async def upsert_samples( self, samples: list[SampleUpsertInternal], open_transaction: bool = True, - author: str = None, project: ProjectId = None, check_project_id=True, ) -> list[SampleUpsertInternal]: @@ -300,7 +295,6 @@ async def upsert_samples( for sample in samples: await self.upsert_sample( sample, - author=author, project=project, process_sequencing_groups=False, process_assays=False, @@ -329,20 +323,18 @@ async def merge_samples( self, id_keep: int, id_merge: int, - author=None, check_project_id=True, ): """Merge two samples into one another""" if check_project_id: projects = await self.st.get_project_ids_for_sample_ids([id_keep, id_merge]) await self.ptable.check_access_to_project_ids( - user=author or self.author, project_ids=projects, readonly=False + user=self.author, project_ids=projects, readonly=False ) return await self.st.merge_samples( id_keep=id_keep, id_merge=id_merge, - author=author, ) async def update_many_participant_ids( diff --git a/db/python/layers/search.py b/db/python/layers/search.py index eb0ea2742..3e94a7ca9 100644 --- a/db/python/layers/search.py +++ b/db/python/layers/search.py @@ -1,7 +1,7 @@ import asyncio from typing import List, Optional -from db.python.connect import NotFoundError +from db.python.utils import NotFoundError from db.python.layers.base import BaseLayer, Connection from db.python.tables.family import FamilyTable from db.python.tables.participant import ParticipantTable @@ -28,7 +28,7 @@ class SearchLayer(BaseLayer): def __init__(self, connection: Connection): super().__init__(connection) - self.pt = ProjectPermissionsTable(connection.connection) + self.pt = ProjectPermissionsTable(connection) self.connection = connection @staticmethod diff --git a/db/python/layers/seqr.py b/db/python/layers/seqr.py index 2bf1de14d..cad435852 100644 --- a/db/python/layers/seqr.py +++ b/db/python/layers/seqr.py @@ -29,7 +29,7 @@ from db.python.layers.participant import ParticipantLayer from db.python.layers.sequencing_group import SequencingGroupLayer from db.python.tables.analysis import AnalysisFilter -from db.python.tables.project import Project, ProjectPermissionsTable +from db.python.tables.project import Project from db.python.utils import GenericFilter from models.enums import AnalysisStatus @@ -122,8 +122,7 @@ async def sync_dataset( raise ValueError('Seqr synchronisation is not configured in metamist') token = self.generate_seqr_auth_token() - pptable = 
ProjectPermissionsTable(connection=self.connection.connection) - project = await pptable.get_and_check_access_to_project_for_id( + project = await self.ptable.get_and_check_access_to_project_for_id( self.connection.author, project_id=self.connection.project, readonly=True, diff --git a/db/python/layers/sequencing_group.py b/db/python/layers/sequencing_group.py index 2e7d073f5..e68150c76 100644 --- a/db/python/layers/sequencing_group.py +++ b/db/python/layers/sequencing_group.py @@ -1,6 +1,6 @@ from datetime import date -from db.python.connect import Connection, NotFoundError +from db.python.connect import Connection from db.python.layers.assay import AssayLayer from db.python.layers.base import BaseLayer from db.python.tables.assay import AssayTable, NoOpAenter @@ -9,11 +9,12 @@ SequencingGroupFilter, SequencingGroupTable, ) -from db.python.utils import ProjectId +from db.python.utils import NotFoundError +from models.models.project import ProjectId from models.models.sequencing_group import ( SequencingGroupInternal, - SequencingGroupUpsertInternal, SequencingGroupInternalId, + SequencingGroupUpsertInternal, ) from models.utils.sequencing_group_id_format import sequencing_group_id_format @@ -261,7 +262,6 @@ async def recreate_sequencing_group_with_new_assays( platform=seqgroup.platform, meta={**seqgroup.meta, **meta}, assay_ids=assays, - author=self.author, open_transaction=False, ) diff --git a/db/python/layers/web.py b/db/python/layers/web.py index d977de09d..20e6c82fd 100644 --- a/db/python/layers/web.py +++ b/db/python/layers/web.py @@ -7,12 +7,12 @@ from datetime import date from api.utils import group_by -from db.python.connect import DbBase from db.python.layers.base import BaseLayer from db.python.layers.sample import SampleLayer from db.python.layers.seqr import SeqrLayer from db.python.tables.analysis import AnalysisTable from db.python.tables.assay import AssayTable +from db.python.tables.base import DbBase from db.python.tables.project import ProjectPermissionsTable from db.python.tables.sequencing_group import SequencingGroupTable from models.models import ( @@ -22,6 +22,7 @@ NestedSampleInternal, NestedSequencingGroupInternal, SearchItem, + parse_sql_bool, ) from models.models.web import ProjectSummaryInternal, WebProject @@ -189,11 +190,7 @@ def _project_summary_process_sample_rows( created_date=str(sample_id_start_times.get(s['id'], '')), sequencing_groups=sg_models_by_sample_id.get(s['id'], []), non_sequencing_assays=filtered_assay_models_by_sid.get(s['id'], []), - active=bool( - ord(s['active']) - if isinstance(s['active'], (str, bytes, bytearray)) - else bool(s['active']) - ), + active=parse_sql_bool(s['active']), ) for s in sample_rows ] @@ -282,7 +279,7 @@ async def get_project_summary( # do initial query to get sample info sampl = SampleLayer(self._connection) sample_query, values = self._project_summary_sample_query(grid_filter) - ptable = ProjectPermissionsTable(self.connection) + ptable = ProjectPermissionsTable(self._connection) project_db = await ptable.get_and_check_access_to_project_for_id( self.author, self.project, readonly=True ) diff --git a/db/python/tables/analysis.py b/db/python/tables/analysis.py index b079a0e42..ee445345e 100644 --- a/db/python/tables/analysis.py +++ b/db/python/tables/analysis.py @@ -4,16 +4,18 @@ from collections import defaultdict from typing import Any, Dict, List, Optional, Set, Tuple -from db.python.connect import DbBase, NotFoundError -from db.python.tables.project import ProjectId +from db.python.tables.base import DbBase from 
db.python.utils import ( GenericFilter, GenericFilterModel, GenericMetaFilter, + NotFoundError, to_db_json, ) from models.enums import AnalysisStatus from models.models.analysis import AnalysisInternal +from models.models.audit_log import AuditLogInternal +from models.models.project import ProjectId @dataclasses.dataclass @@ -60,7 +62,6 @@ async def create_analysis( meta: Optional[Dict[str, Any]] = None, output: str = None, active: bool = True, - author: str = None, project: ProjectId = None, ) -> int: """ @@ -73,14 +74,11 @@ async def create_analysis( ('status', status.value), ('meta', to_db_json(meta or {})), ('output', output), - ('author', author or self.author), + ('audit_log_id', await self.audit_log_id()), ('project', project or self.project), ('active', active if active is not None else True), ] - if author is not None: - kv_pairs.append(('on_behalf_of', self.author)) - if status == AnalysisStatus.COMPLETED: kv_pairs.append(('timestamp_completed', datetime.datetime.utcnow())) @@ -110,11 +108,19 @@ async def add_sequencing_groups_to_analysis( """Add samples to an analysis (through the linked table)""" _query = """ INSERT INTO analysis_sequencing_group - (analysis_id, sequencing_group_id) - VALUES (:aid, :sid) + (analysis_id, sequencing_group_id, audit_log_id) + VALUES (:aid, :sid, :audit_log_id) """ - values = map(lambda sid: {'aid': analysis_id, 'sid': sid}, sequencing_group_ids) + audit_log_id = await self.audit_log_id() + values = map( + lambda sid: { + 'aid': analysis_id, + 'sid': sid, + 'audit_log_id': audit_log_id, + }, + sequencing_group_ids, + ) await self.connection.execute_many(_query, list(values)) async def find_sgs_in_joint_call_or_es_index_up_to_date( @@ -140,18 +146,17 @@ async def update_analysis( meta: Dict[str, Any] = None, active: bool = None, output: Optional[str] = None, - author: Optional[str] = None, ): """ Update the status of an analysis, set timestamp_completed if relevant """ fields: Dict[str, Any] = { - 'author': self.author or author, - 'on_behalf_of': self.author, 'analysis_id': analysis_id, + 'on_behalf_of': self.author, + 'audit_log_id': await self.audit_log_id(), } - setters = ['author = :author', 'on_behalf_of = :on_behalf_of'] + setters = ['audit_log_id = :audit_log_id', 'on_behalf_of = :on_behalf_of'] if status: setters.append('status = :status') fields['status'] = status.value @@ -391,7 +396,7 @@ async def get_analyses_for_samples( self, sample_ids: list[int], analysis_type: str | None, - status: AnalysisStatus, + status: AnalysisStatus | None, ) -> tuple[set[ProjectId], list[AnalysisInternal]]: """ Get relevant analyses for a sample, optional type / status filters @@ -497,8 +502,8 @@ async def get_analysis_runner_log( values['project_ids'] = project_ids if author: - wheres.append('author = :author') - values['author'] = author + wheres.append('audit_log_id = :audit_log_id') + values['audit_log_id'] = await self.audit_log_id() if output_dir: wheres.append('(output = :output OR output LIKE :output_like)') @@ -589,3 +594,11 @@ async def get_sg_add_to_project_es_index( rows = await self.connection.fetch_all(_query, {'sg_ids': sg_ids}) return {r['sg_id']: r['timestamp_completed'].date() for r in rows} + + async def get_audit_log_for_analysis_ids( + self, analysis_ids: list[int] + ) -> dict[int, list[AuditLogInternal]]: + """ + Get audit logs for analysis IDs + """ + return await self.get_all_audit_logs_for_table('analysis', analysis_ids) diff --git a/db/python/tables/assay.py b/db/python/tables/assay.py index c0a3a27d7..dc0604ef3 100644 --- 
a/db/python/tables/assay.py +++ b/db/python/tables/assay.py @@ -4,15 +4,17 @@ from collections import defaultdict from typing import Any -from db.python.connect import DbBase, NotFoundError, NoOpAenter -from db.python.tables.project import ProjectId +from db.python.tables.base import DbBase from db.python.utils import ( - to_db_json, - GenericFilterModel, GenericFilter, + GenericFilterModel, GenericMetaFilter, + NoOpAenter, + NotFoundError, + to_db_json, ) from models.models.assay import AssayInternal +from models.models.project import ProjectId REPLACEMENT_KEY_INVALID_CHARS = re.compile(r'[^\w\d_]') @@ -102,7 +104,7 @@ async def query( drow = dict(row) project_ids.add(drow.pop('project')) assay = AssayInternal.from_db(drow) - assay.external_ids = seq_eids.get(assay.id, {}) + assay.external_ids = seq_eids.get(assay.id, {}) if assay.id else {} assays.append(assay) return project_ids, assays @@ -137,7 +139,7 @@ async def get_assay_by_id(self, assay_id: int) -> tuple[ProjectId, AssayInternal return pjcts.pop(), assays.pop() async def get_assay_by_external_id( - self, external_sequence_id: str, project: int = None + self, external_sequence_id: str, project: ProjectId | None = None ) -> AssayInternal: """Get assay by EXTERNAL ID""" if not (project or self.project): @@ -214,7 +216,6 @@ async def insert_assay( external_ids: dict[str, str] | None, assay_type: str, meta: dict[str, Any] | None, - author: str | None = None, project: int | None = None, open_transaction: bool = True, ) -> int: @@ -243,8 +244,8 @@ async def insert_assay( _query = """\ INSERT INTO assay - (sample_id, meta, type, author) - VALUES (:sample_id, :meta, :type, :author) + (sample_id, meta, type, audit_log_id) + VALUES (:sample_id, :meta, :type, :audit_log_id) RETURNING id; """ @@ -257,7 +258,7 @@ async def insert_assay( 'sample_id': sample_id, 'meta': to_db_json(meta), 'type': assay_type, - 'author': author or self.author, + 'audit_log_id': await self.audit_log_id(), }, ) @@ -271,16 +272,17 @@ async def insert_assay( _eid_query = """ INSERT INTO assay_external_id - (project, assay_id, external_id, name, author) - VALUES (:project, :assay_id, :external_id, :name, :author); + (project, assay_id, external_id, name, audit_log_id) + VALUES (:project, :assay_id, :external_id, :name, :audit_log_id); """ + audit_log_id = await self.audit_log_id() eid_values = [ { 'project': project or self.project, 'assay_id': id_of_new_assay, 'external_id': eid, 'name': name.lower(), - 'author': author or self.author, + 'audit_log_id': audit_log_id, } for name, eid in external_ids.items() ] @@ -290,7 +292,7 @@ async def insert_assay( return id_of_new_assay async def insert_many_assays( - self, assays: list[AssayInternal], author=None, open_transaction: bool = True + self, assays: list[AssayInternal], open_transaction: bool = True ): """Insert many sequencing, returning no IDs""" with_function = self.connection.transaction if open_transaction else NoOpAenter @@ -305,13 +307,14 @@ async def insert_many_assays( sample_id=assay.sample_id, external_ids=assay.external_ids, meta=assay.meta, - author=author, assay_type=assay.type, open_transaction=False, ) ) return assay_ids + # endregion INSERTS + async def update_assay( self, assay_id: int, @@ -322,15 +325,18 @@ async def update_assay( sample_id: int | None = None, project: ProjectId | None = None, open_transaction: bool = True, - author=None, ): """Update an assay""" with_function = self.connection.transaction if open_transaction else NoOpAenter async with with_function(): - fields = {'assay_id': 
assay_id, 'author': author or self.author} + audit_log_id = await self.audit_log_id() + fields: dict[str, Any] = { + 'assay_id': assay_id, + 'audit_log_id': audit_log_id, + } - updaters = ['author = :author'] + updaters = ['audit_log_id = :audit_log_id'] if meta is not None: updaters.append('meta = JSON_MERGE_PATCH(COALESCE(meta, "{}"), :meta)') fields['meta'] = to_db_json(meta) @@ -365,7 +371,20 @@ async def update_assay( } if to_delete: + _assay_eid_update_before_delete = """ + UPDATE assay_external_id + SET audit_log_id = :audit_log_id + WHERE assay_id = :assay_id AND name in :names + """ _delete_query = 'DELETE FROM assay_external_id WHERE assay_id = :assay_id AND name in :names' + await self.connection.execute( + _assay_eid_update_before_delete, + { + 'assay_id': assay_id, + 'names': list(to_delete), + 'audit_log_id': audit_log_id, + }, + ) await self.connection.execute( _delete_query, {'assay_id': assay_id, 'names': list(to_delete)}, @@ -377,17 +396,18 @@ async def update_assay( ) _update_query = """\ - INSERT INTO assay_external_id (project, assay_id, external_id, name, author) - VALUES (:project, :assay_id, :external_id, :name, :author) - ON DUPLICATE KEY UPDATE external_id = :external_id, author = :author + INSERT INTO assay_external_id (project, assay_id, external_id, name, audit_log_id) + VALUES (:project, :assay_id, :external_id, :name, :audit_log_id) + ON DUPLICATE KEY UPDATE external_id = :external_id, audit_log_id = :audit_log_id """ + audit_log_id = await self.audit_log_id() values = [ { 'project': project, 'assay_id': assay_id, 'external_id': eid, 'name': name, - 'author': author or self.author, + 'audit_log_id': audit_log_id, } for name, eid in to_update.items() ] @@ -397,13 +417,13 @@ async def update_assay( async def get_assays_by( self, - assay_ids: list[int] = None, - sample_ids: list[int] = None, - assay_types: list[str] = None, - assay_meta: dict[str, Any] = None, - sample_meta: dict[str, Any] = None, - external_assay_ids: list[str] = None, - project_ids: list[int] = None, + assay_ids: list[int] | None = None, + sample_ids: list[int] | None = None, + assay_types: list[str] | None = None, + assay_meta: dict[str, Any] | None = None, + sample_meta: dict[str, Any] | None = None, + external_assay_ids: list[str] | None = None, + project_ids: list[int] | None = None, active: bool = True, ) -> tuple[list[ProjectId], list[AssayInternal]]: """Get sequences by some criteria""" diff --git a/db/python/tables/audit_log.py b/db/python/tables/audit_log.py new file mode 100644 index 000000000..f3f24038e --- /dev/null +++ b/db/python/tables/audit_log.py @@ -0,0 +1,74 @@ +from typing import Any + +from db.python.tables.base import DbBase +from db.python.utils import to_db_json +from models.models.audit_log import AuditLogInternal +from models.models.project import ProjectId + + +class AuditLogTable(DbBase): + """ + Capture Analysis table operations and queries + """ + + table_name = 'audit_log' + + async def get_projects_for_ids(self, audit_log_ids: list[int]) -> set[ProjectId]: + """Get project IDs for sampleIds (mostly for checking auth)""" + _query = """ + SELECT DISTINCT auth_project + FROM audit_log + WHERE id in :audit_log_ids + """ + if len(audit_log_ids) == 0: + raise ValueError('Received no audit log IDs') + rows = await self.connection.fetch_all(_query, {'audit_log_ids': audit_log_ids}) + return {r['project'] for r in rows} + + async def get_audit_logs_for_ids( + self, audit_log_ids: list[int] + ) -> list[AuditLogInternal]: + """Get project IDs for sampleIds (mostly for 
checking auth)""" + _query = """ + SELECT id, timestamp, author, on_behalf_of, ar_guid, comment, auth_project + FROM audit_log + WHERE id in :audit_log_ids + """ + if len(audit_log_ids) == 0: + raise ValueError('Received no audit log IDs') + rows = await self.connection.fetch_all(_query, {'audit_log_ids': audit_log_ids}) + return [AuditLogInternal.from_db(dict(r)) for r in rows] + + async def create_audit_log( + self, + author: str, + on_behalf_of: str | None, + ar_guid: str | None, + comment: str | None, + project: ProjectId | None, + meta: dict[str, Any] | None = None, + ) -> int: + """ + Create a new audit log entry + """ + + _query = """ + INSERT INTO audit_log + (author, on_behalf_of, ar_guid, comment, auth_project, meta) + VALUES + (:author, :on_behalf_of, :ar_guid, :comment, :project, :meta) + RETURNING id + """ + audit_log_id = await self.connection.fetch_val( + _query, + { + 'author': author, + 'on_behalf_of': on_behalf_of, + 'ar_guid': ar_guid, + 'comment': comment, + 'project': project, + 'meta': to_db_json(meta or {}), + }, + ) + + return audit_log_id diff --git a/db/python/tables/base.py b/db/python/tables/base.py new file mode 100644 index 000000000..e29353de8 --- /dev/null +++ b/db/python/tables/base.py @@ -0,0 +1,76 @@ +from collections import defaultdict + +import databases + +from db.python.connect import Connection +from db.python.utils import InternalError +from models.models.audit_log import AuditLogInternal + + +class DbBase: + """Base class for table subclasses""" + + def __init__(self, connection: Connection): + if connection is None: + raise InternalError( + f'No connection was provided to the table {self.__class__.__name__!r}' + ) + if not isinstance(connection, Connection): + raise InternalError( + f'Expected connection type Connection, received {type(connection)}, ' + f'did you mean to call self._connection?' 
+ ) + + self._connection = connection + self.connection: databases.Database = connection.connection + self.author = connection.author + self.project = connection.project + + if self.author is None: + raise InternalError(f'Must provide author to {self.__class__.__name__}') + + async def audit_log_id(self): + """ + Get audit_log ID (or fail otherwise) + """ + return await self._connection.audit_log_id() + + # piped from the connection + + @staticmethod + def escape_like_term(query: str): + """ + Escape meaningful keys when using LIKE with a user supplied input + """ + return query.replace('%', '\\%').replace('_', '\\_') + + async def get_all_audit_logs_for_table( + self, table: str, ids: list[int], id_field='id' + ) -> dict[int, list[AuditLogInternal]]: + """ + Get all audit logs for values from a table + """ + _query = f""" + SELECT + t.{id_field} as table_id, + al.id as id, + al.author as author, + al.on_behalf_of as on_behalf_of, + al.timestamp as timestamp, + al.ar_guid as ar_guid, + al.comment as comment, + al.auth_project as auth_project, + al.meta as meta + FROM {table} FOR SYSTEM_TIME ALL t + INNER JOIN audit_log al + ON al.id = t.audit_log_id + WHERE t.{id_field} in :ids + """.strip() + rows = await self.connection.fetch_all(_query, {'ids': ids}) + by_id = defaultdict(list) + for r in rows: + row = dict(r) + id_value = row.pop('table_id') + by_id[id_value].append(AuditLogInternal.from_db(row)) + + return by_id diff --git a/db/python/tables/billing.py b/db/python/tables/billing.py deleted file mode 100644 index 54402a85c..000000000 --- a/db/python/tables/billing.py +++ /dev/null @@ -1,18 +0,0 @@ -import dataclasses - -from db.python.utils import ( - GenericFilter, - GenericFilterModel, -) - - -@dataclasses.dataclass -class BillingFilter(GenericFilterModel): - """Filter for billing""" - - topic: GenericFilter[str] = None - date: GenericFilter[str] = None - cost_category: GenericFilter[str] = None - - def __hash__(self): # pylint: disable=useless-parent-delegation - return super().__hash__() diff --git a/db/python/tables/bq/billing_ar_batch.py b/db/python/tables/bq/billing_ar_batch.py new file mode 100644 index 000000000..d9326b6b3 --- /dev/null +++ b/db/python/tables/bq/billing_ar_batch.py @@ -0,0 +1,69 @@ +from datetime import datetime, timedelta + +from google.cloud import bigquery + +from api.settings import BQ_BATCHES_VIEW +from db.python.tables.bq.billing_base import BillingBaseTable + + +class BillingArBatchTable(BillingBaseTable): + """Billing AR - BatchID lookup Big Query table""" + + table_name = BQ_BATCHES_VIEW + + def get_table_name(self): + """Get table name""" + return self.table_name + + async def get_batches_by_ar_guid( + self, ar_guid: str + ) -> tuple[datetime, datetime, list[str]]: + """ + Get batches for given ar_guid + """ + _query = f""" + SELECT + batch_id, + MIN(min_day) as start_day, + MAX(max_day) as end_day + FROM `{self.table_name}` + WHERE ar_guid = @ar_guid + AND batch_id IS NOT NULL + GROUP BY batch_id + ORDER BY batch_id; + """ + + query_parameters = [ + bigquery.ScalarQueryParameter('ar_guid', 'STRING', ar_guid), + ] + query_job_result = self._execute_query(_query, query_parameters) + + if query_job_result: + start_day = min((row.start_day for row in query_job_result)) + end_day = max((row.end_day for row in query_job_result)) + timedelta(days=1) + return start_day, end_day, [row.batch_id for row in query_job_result] + + # return empty list if no record found + return None, None, [] + + async def get_ar_guid_by_batch_id(self, batch_id: str) -> str: + 
""" + Get ar_guid for given batch_id + """ + _query = f""" + SELECT ar_guid + FROM `{self.table_name}` + WHERE batch_id = @batch_id + AND ar_guid IS NOT NULL + LIMIT 1; + """ + + query_parameters = [ + bigquery.ScalarQueryParameter('batch_id', 'STRING', batch_id), + ] + query_job_result = self._execute_query(_query, query_parameters) + if query_job_result: + return query_job_result[0]['ar_guid'] + + # return None if no ar_guid found + return None diff --git a/db/python/tables/bq/billing_base.py b/db/python/tables/bq/billing_base.py new file mode 100644 index 000000000..335603c3b --- /dev/null +++ b/db/python/tables/bq/billing_base.py @@ -0,0 +1,695 @@ +import re +from abc import ABCMeta, abstractmethod +from collections import Counter, defaultdict, namedtuple +from datetime import datetime +from typing import Any + +from google.cloud import bigquery + +from api.settings import BQ_BUDGET_VIEW, BQ_DAYS_BACK_OPTIMAL +from api.utils.dates import get_invoice_month_range, reformat_datetime +from db.python.gcp_connect import BqDbBase +from db.python.tables.bq.billing_filter import BillingFilter +from db.python.tables.bq.function_bq_filter import FunctionBQFilter +from db.python.tables.bq.generic_bq_filter import GenericBQFilter +from models.enums import BillingTimeColumn, BillingTimePeriods +from models.models import ( + BillingColumn, + BillingCostBudgetRecord, + BillingCostDetailsRecord, + BillingTotalCostQueryModel, +) + +# Label added to each Billing Big Query request, +# so we can track the cost of metamist-api BQ usage +BQ_LABELS = {'source': 'metamist-api'} + + +# Day Time details used in grouping and parsing formulas +TimeGroupingDetails = namedtuple( + 'TimeGroupingDetails', ['field', 'formula', 'separator'] +) + +# constants to abbrevate (S)tores and (C)ompute +STORAGE = 'S' +COMPUTE = 'C' + + +def abbrev_cost_category(cost_category: str) -> str: + """abbreviate cost category""" + return STORAGE if cost_category == 'Cloud Storage' else COMPUTE + + +def prepare_time_periods( + query: BillingTotalCostQueryModel, +) -> TimeGroupingDetails: + """Prepare Time periods grouping and parsing formulas""" + time_column: BillingTimeColumn = query.time_column or BillingTimeColumn.DAY + + # Based on specified time period, add the corresponding column + if query.time_periods == BillingTimePeriods.DAY: + return TimeGroupingDetails( + field=f'FORMAT_DATE("%Y-%m-%d", {time_column.value}) as day', + formula='PARSE_DATE("%Y-%m-%d", day) as day', + separator=',', + ) + + if query.time_periods == BillingTimePeriods.WEEK: + return TimeGroupingDetails( + field=f'FORMAT_DATE("%Y%W", {time_column.value}) as day', + formula='PARSE_DATE("%Y%W", day) as day', + separator=',', + ) + + if query.time_periods == BillingTimePeriods.MONTH: + return TimeGroupingDetails( + field=f'FORMAT_DATE("%Y%m", {time_column.value}) as day', + formula='PARSE_DATE("%Y%m", day) as day', + separator=',', + ) + + if query.time_periods == BillingTimePeriods.INVOICE_MONTH: + return TimeGroupingDetails( + field='invoice_month as day', + formula='PARSE_DATE("%Y%m", day) as day', + separator=',', + ) + + return TimeGroupingDetails('', '', '') + + +def time_optimisation_parameter() -> bigquery.ScalarQueryParameter: + """ + BQ tables and views are partitioned by day, to avoid full scans + we need to limit the amount of data scanned + """ + return bigquery.ScalarQueryParameter('days', 'INT64', -int(BQ_DAYS_BACK_OPTIMAL)) + + +class BillingBaseTable(BqDbBase): + """Billing Base Table + This is abstract class, it should not be instantiated + """ + 
+ __metaclass__ = ABCMeta + + @abstractmethod + def get_table_name(self): + """Get table name""" + raise NotImplementedError('Calling Abstract method directly') + + def _execute_query( + self, query: str, params: list[Any] = None, results_as_list: bool = True + ) -> list[Any]: + """Execute query, add BQ labels""" + if params: + job_config = bigquery.QueryJobConfig( + query_parameters=params, labels=BQ_LABELS + ) + else: + job_config = bigquery.QueryJobConfig(labels=BQ_LABELS) + + if results_as_list: + return list( + self._connection.connection.query(query, job_config=job_config).result() + ) + + # otherwise return as BQ iterator + return self._connection.connection.query(query, job_config=job_config) + + def _query_to_partitioned_filter( + self, query: BillingTotalCostQueryModel + ) -> BillingFilter: + """ + By default views are partitioned by 'day', + if different then overwrite in the subclass + """ + billing_filter = query.to_filter() + + # initial partition filter + billing_filter.day = GenericBQFilter[datetime]( + gte=datetime.strptime(query.start_date, '%Y-%m-%d') + if query.start_date + else None, + lte=datetime.strptime(query.end_date, '%Y-%m-%d') + if query.end_date + else None, + ) + return billing_filter + + def _filter_to_optimise_query(self) -> str: + """Filter string to optimise BQ query""" + return 'day >= TIMESTAMP(@start_day) AND day <= TIMESTAMP(@last_day)' + + def _last_loaded_day_filter(self) -> str: + """Last Loaded day filter string""" + return 'day = TIMESTAMP(@last_loaded_day)' + + def _convert_output(self, query_job_result): + """Convert query result to json""" + if not query_job_result or query_job_result.result().total_rows == 0: + # return empty list if no record found + return [] + + records = query_job_result.result() + results = [] + + def transform_labels(row): + return {r['key']: r['value'] for r in row} + + for record in records: + drec = dict(record) + if 'labels' in drec: + drec.update(transform_labels(drec['labels'])) + + results.append(drec) + + return results + + async def _budgets_by_gcp_project( + self, field: BillingColumn, is_current_month: bool + ) -> dict[str, float]: + """ + Get budget for gcp-projects + """ + if field != BillingColumn.GCP_PROJECT or not is_current_month: + # only projects have budget and only for current month + return {} + + _query = f""" + WITH t AS ( + SELECT gcp_project, MAX(created_at) as last_created_at + FROM `{BQ_BUDGET_VIEW}` + GROUP BY gcp_project + ) + SELECT t.gcp_project, d.budget + FROM t inner join `{BQ_BUDGET_VIEW}` d + ON d.gcp_project = t.gcp_project AND d.created_at = t.last_created_at + """ + + query_job_result = self._execute_query(_query) + if query_job_result: + return {row.gcp_project: row.budget for row in query_job_result} + + return {} + + async def _last_loaded_day(self): + """Get the most recent fully loaded day in db + Go 2 days back as the data is not always available for the current day + 1 day back is not enough + """ + + _query = f""" + SELECT TIMESTAMP_ADD(MAX(day), INTERVAL -2 DAY) as last_loaded_day + FROM `{self.get_table_name()}` + WHERE day > TIMESTAMP_ADD( + CURRENT_TIMESTAMP(), INTERVAL @days DAY + ) + """ + + query_parameters = [ + time_optimisation_parameter(), + ] + query_job_result = self._execute_query(_query, query_parameters) + + if query_job_result: + return str(query_job_result[0].last_loaded_day) + + return None + + def _prepare_daily_cost_subquery(self, field, query_params, last_loaded_day): + """prepare daily cost subquery""" + + daily_cost_field = ', day.cost as 
daily_cost' + daily_cost_join = f"""LEFT JOIN ( + SELECT + {field.value} as field, + cost_category, + SUM(cost) as cost + FROM + `{self.get_table_name()}` + WHERE {self._last_loaded_day_filter()} + GROUP BY + field, + cost_category + ) day + ON month.field = day.field + AND month.cost_category = day.cost_category + """ + + query_params.append( + bigquery.ScalarQueryParameter('last_loaded_day', 'STRING', last_loaded_day), + ) + return (query_params, daily_cost_field, daily_cost_join) + + async def _execute_running_cost_query( + self, + field: BillingColumn, + invoice_month: str | None = None, + ): + """ + Run query to get running cost of selected field + """ + # check if invoice month is valid first + if not invoice_month or not re.match(r'^\d{6}$', invoice_month): + raise ValueError('Invalid invoice month') + + invoice_month_date = datetime.strptime(invoice_month, '%Y%m') + if invoice_month != invoice_month_date.strftime('%Y%m'): + raise ValueError('Invalid invoice month') + + # get start day and current day for given invoice month + # This is to optimise the query, BQ view is partitioned by day + # and not by invoice month + start_day_date, last_day_date = get_invoice_month_range(invoice_month_date) + start_day = start_day_date.strftime('%Y-%m-%d') + last_day = last_day_date.strftime('%Y-%m-%d') + + # start_day and last_day are in to optimise the query + query_params = [ + bigquery.ScalarQueryParameter('start_day', 'STRING', start_day), + bigquery.ScalarQueryParameter('last_day', 'STRING', last_day), + ] + + current_day = datetime.now().strftime('%Y-%m-%d') + is_current_month = last_day >= current_day + last_loaded_day = None + + if is_current_month: + # Only current month can have last 24 hours cost + # Last 24H in UTC time + # Find the last fully loaded day in the view + last_loaded_day = await self._last_loaded_day() + ( + query_params, + daily_cost_field, + daily_cost_join, + ) = self._prepare_daily_cost_subquery(field, query_params, last_loaded_day) + else: + # Do not calculate last 24H cost + daily_cost_field = ', NULL as daily_cost' + daily_cost_join = '' + + _query = f""" + SELECT + CASE WHEN month.field IS NULL THEN 'N/A' ELSE month.field END as field, + month.cost_category, + month.cost as monthly_cost + {daily_cost_field} + FROM + ( + SELECT + {field.value} as field, + cost_category, + SUM(cost) as cost + FROM + `{self.get_table_name()}` + WHERE {self._filter_to_optimise_query()} + AND invoice_month = @invoice_month + GROUP BY + field, + cost_category + HAVING cost > 0.1 + ) month + {daily_cost_join} + ORDER BY field ASC, daily_cost DESC, monthly_cost DESC; + """ + + query_params.append( + bigquery.ScalarQueryParameter('invoice_month', 'STRING', invoice_month) + ) + + return ( + is_current_month, + last_loaded_day, + self._execute_query(_query, query_params), + ) + + async def _append_total_running_cost( + self, + field: BillingColumn, + is_current_month: bool, + last_loaded_day: str | None, + total_monthly: dict, + total_daily: dict, + total_monthly_category: dict, + total_daily_category: dict, + results: list[BillingCostBudgetRecord], + ) -> list[BillingCostBudgetRecord]: + """ + Add total row: compute + storage to the results + """ + # construct ALL fields details + all_details = [] + for cat, mth_cost in total_monthly_category.items(): + all_details.append( + BillingCostDetailsRecord( + cost_group=abbrev_cost_category(cat), + cost_category=cat, + daily_cost=total_daily_category[cat] if is_current_month else None, + monthly_cost=mth_cost, + ) + ) + + # add total row: compute 
+ storage + results.append( + BillingCostBudgetRecord( + field=f'{BillingColumn.generate_all_title(field)}', + total_monthly=( + total_monthly[COMPUTE]['ALL'] + total_monthly[STORAGE]['ALL'] + ), + total_daily=(total_daily[COMPUTE]['ALL'] + total_daily[STORAGE]['ALL']) + if is_current_month + else None, + compute_monthly=total_monthly[COMPUTE]['ALL'], + compute_daily=(total_daily[COMPUTE]['ALL']) + if is_current_month + else None, + storage_monthly=total_monthly[STORAGE]['ALL'], + storage_daily=(total_daily[STORAGE]['ALL']) + if is_current_month + else None, + details=all_details, + budget_spent=None, + budget=None, + last_loaded_day=last_loaded_day, + ) + ) + + return results + + async def _append_running_cost_records( + self, + field: BillingColumn, + is_current_month: bool, + last_loaded_day: str | None, + total_monthly: dict, + total_daily: dict, + field_details: dict, + results: list[BillingCostBudgetRecord], + ) -> list[BillingCostBudgetRecord]: + """ + Add all the selected field rows: compute + storage to the results + """ + # get budget map per gcp project + budgets_per_gcp_project = await self._budgets_by_gcp_project( + field, is_current_month + ) + + # add rows by field + for key, details in field_details.items(): + compute_daily = ( + total_daily[COMPUTE][key] if key in total_daily[COMPUTE] else 0 + ) + storage_daily = ( + total_daily[STORAGE][key] if key in total_daily[STORAGE] else 0 + ) + compute_monthly = ( + total_monthly[COMPUTE][key] if key in total_monthly[COMPUTE] else 0 + ) + storage_monthly = ( + total_monthly[STORAGE][key] if key in total_monthly[STORAGE] else 0 + ) + monthly = compute_monthly + storage_monthly + budget_monthly = budgets_per_gcp_project.get(key) + + results.append( + BillingCostBudgetRecord.from_json( + { + 'field': key, + 'total_monthly': monthly, + 'total_daily': (compute_daily + storage_daily) + if is_current_month + else None, + 'compute_monthly': compute_monthly, + 'compute_daily': compute_daily, + 'storage_monthly': storage_monthly, + 'storage_daily': storage_daily, + 'details': details, + 'budget_spent': 100 * monthly / budget_monthly + if budget_monthly + else None, + 'budget': budget_monthly, + 'last_loaded_day': last_loaded_day, + } + ) + ) + + return results + + def _prepare_order_by_string( + self, order_by: dict[BillingColumn, bool] | None + ) -> str: + """Prepare order by string""" + if not order_by: + return '' + + order_by_cols = [] + for order_field, reverse in order_by.items(): + col_name = str(order_field.value) + col_order = 'DESC' if reverse else 'ASC' + order_by_cols.append(f'{col_name} {col_order}') + + return f'ORDER BY {",".join(order_by_cols)}' if order_by_cols else '' + + def _prepare_aggregation( + self, query: BillingTotalCostQueryModel + ) -> tuple[str, str]: + """Prepare both fields for aggregation and group by string""" + # Get columns to group by + + # if group by is populated, then we need to group by day as well + grp_columns = ['day'] if query.group_by else [] + + for field in query.fields: + col_name = str(field.value) + if not BillingColumn.can_group_by(field): + # if the field cannot be grouped by, skip it + continue + + # append to potential columns to group by + grp_columns.append(col_name) + + fields_selected = ','.join( + (field.value for field in query.fields if field != BillingColumn.COST) + ) + + grp_selected = ','.join(grp_columns) + group_by = f'GROUP BY {grp_selected}' if query.group_by else '' + + return fields_selected, group_by + + def _prepare_labels_function(self, query: 
BillingTotalCostQueryModel): + if not query.filters: + return None + + if BillingColumn.LABELS in query.filters and isinstance( + query.filters[BillingColumn.LABELS], dict + ): + # prepare labels as function filters, parameterized both sides + func_filter = FunctionBQFilter( + name='getLabelValue', + implementation=""" + CREATE TEMP FUNCTION getLabelValue( + labels ARRAY>, label STRING + ) AS ( + (SELECT value FROM UNNEST(labels) WHERE key = label LIMIT 1) + ); + """, + ) + func_filter.to_sql( + BillingColumn.LABELS, + query.filters[BillingColumn.LABELS], + query.filters_op, + ) + return func_filter + + # otherwise + return None + + async def get_total_cost( + self, + query: BillingTotalCostQueryModel, + ) -> list[dict] | None: + """ + Get Total cost of selected fields for requested time interval from BQ views + """ + if not query.start_date or not query.end_date or not query.fields: + raise ValueError('Date and Fields are required') + + # Get columns to select and to group by + fields_selected, group_by = self._prepare_aggregation(query) + + # construct order by + order_by_str = self._prepare_order_by_string(query.order_by) + + # prepare grouping by time periods + time_group = TimeGroupingDetails('', '', '') + if query.time_periods or query.time_column: + time_group = prepare_time_periods(query) + + # overrides time specific fields with relevant time column name + query_filter = self._query_to_partitioned_filter(query) + + # prepare where string and SQL parameters + where_str, sql_parameters = query_filter.to_sql() + + # extract only BQ Query parameter, keys are not used in BQ SQL + # have to declare empty list first as linting is not happy + query_parameters: list[ + bigquery.ScalarQueryParameter | bigquery.ArrayQueryParameter + ] = [] + query_parameters.extend(sql_parameters.values()) + + # prepare labels as function filters if present + func_filter = self._prepare_labels_function(query) + if func_filter: + # extend where_str and query_parameters + query_parameters.extend(func_filter.func_sql_parameters) + + # now join Prepared Where with Labels Function Where + where_str = ' AND '.join([where_str, func_filter.func_where]) + + # if group by is populated, then we need SUM the cost, otherwise raw cost + cost_column = 'SUM(cost) as cost' if query.group_by else 'cost' + + if where_str: + # Where is not empty, prepend with WHERE + where_str = f'WHERE {where_str}' + + _query = f""" + {func_filter.fun_implementation if func_filter else ''} + + WITH t AS ( + SELECT {time_group.field}{time_group.separator} {fields_selected}, + {cost_column} + FROM `{self.get_table_name()}` + {where_str} + {group_by} + {order_by_str} + ) + SELECT {time_group.formula}{time_group.separator} {fields_selected}, cost FROM t + """ + + # append min cost condition + if query.min_cost: + _query += ' WHERE cost > @min_cost' + query_parameters.append( + bigquery.ScalarQueryParameter('min_cost', 'FLOAT64', query.min_cost) + ) + + # append LIMIT and OFFSET if present + if query.limit: + _query += ' LIMIT @limit_val' + query_parameters.append( + bigquery.ScalarQueryParameter('limit_val', 'INT64', query.limit) + ) + if query.offset: + _query += ' OFFSET @offset_val' + query_parameters.append( + bigquery.ScalarQueryParameter('offset_val', 'INT64', query.offset) + ) + + query_job_result = self._execute_query( + _query, query_parameters, results_as_list=False + ) + return self._convert_output(query_job_result) + + async def get_running_cost( + self, + field: BillingColumn, + invoice_month: str | None = None, + ) -> 
list[BillingCostBudgetRecord]: + """ + Get currently running cost of selected field + """ + + # accept only Topic, Dataset or Project at this stage + if field not in ( + BillingColumn.TOPIC, + BillingColumn.GCP_PROJECT, + BillingColumn.DATASET, + BillingColumn.STAGE, + BillingColumn.COMPUTE_CATEGORY, + BillingColumn.WDL_TASK_NAME, + BillingColumn.CROMWELL_SUB_WORKFLOW_NAME, + BillingColumn.NAMESPACE, + ): + raise ValueError( + 'Invalid field only topic, dataset, gcp-project, compute_category, ' + 'wdl_task_name, cromwell_sub_workflow_name & namespace are allowed' + ) + + ( + is_current_month, + last_loaded_day, + query_job_result, + ) = await self._execute_running_cost_query(field, invoice_month) + if not query_job_result: + # return empty list + return [] + + # prepare data + results: list[BillingCostBudgetRecord] = [] + + # reformat last_loaded_day if present + last_loaded_day = reformat_datetime( + last_loaded_day, '%Y-%m-%d %H:%M:%S+00:00', '%b %d' + ) + + total_monthly: dict[str, Counter[str]] = defaultdict(Counter) + total_daily: dict[str, Counter[str]] = defaultdict(Counter) + field_details: dict[str, list[Any]] = defaultdict(list) + total_monthly_category: Counter[str] = Counter() + total_daily_category: Counter[str] = Counter() + + for row in query_job_result: + if row.field not in field_details: + field_details[row.field] = [] + + cost_group = abbrev_cost_category(row.cost_category) + + field_details[row.field].append( + { + 'cost_group': cost_group, + 'cost_category': row.cost_category, + 'daily_cost': row.daily_cost if is_current_month else None, + 'monthly_cost': row.monthly_cost, + } + ) + + total_monthly_category[row.cost_category] += row.monthly_cost + if row.daily_cost: + total_daily_category[row.cost_category] += row.daily_cost + + # cost groups totals + total_monthly[cost_group]['ALL'] += row.monthly_cost + total_monthly[cost_group][row.field] += row.monthly_cost + if row.daily_cost and is_current_month: + total_daily[cost_group]['ALL'] += row.daily_cost + total_daily[cost_group][row.field] += row.daily_cost + + # add total row: compute + storage + results = await self._append_total_running_cost( + field, + is_current_month, + last_loaded_day, + total_monthly, + total_daily, + total_monthly_category, + total_daily_category, + results, + ) + + # add rest of the records: compute + storage + results = await self._append_running_cost_records( + field, + is_current_month, + last_loaded_day, + total_monthly, + total_daily, + field_details, + results, + ) + + return results diff --git a/db/python/tables/bq/billing_daily.py b/db/python/tables/bq/billing_daily.py new file mode 100644 index 000000000..14f21cef0 --- /dev/null +++ b/db/python/tables/bq/billing_daily.py @@ -0,0 +1,131 @@ +from google.cloud import bigquery + +from api.settings import BQ_AGGREG_VIEW +from db.python.tables.bq.billing_base import ( + BillingBaseTable, + time_optimisation_parameter, +) + + +class BillingDailyTable(BillingBaseTable): + """Billing Aggregated Daily Biq Query table""" + + table_name = BQ_AGGREG_VIEW + + def get_table_name(self): + """Get table name""" + return self.table_name + + async def get_topics(self): + """Get all topics in database""" + + # cost of this BQ is 10MB on DEV is minimal, AU$ 0.000008 per query + # @days is defined by env variable BQ_DAYS_BACK_OPTIMAL + # this day > filter is to limit the amount of data scanned, + # saving cost for running BQ + # aggregated views are partitioned by day + _query = f""" + SELECT DISTINCT topic + FROM `{self.table_name}` + WHERE day > 
TIMESTAMP_ADD( + CURRENT_TIMESTAMP(), INTERVAL @days DAY + ) + ORDER BY topic ASC; + """ + + query_parameters = [ + time_optimisation_parameter(), + ] + query_job_result = self._execute_query(_query, query_parameters) + + if query_job_result: + return [str(dict(row)['topic']) for row in query_job_result] + + # return empty list if no record found + return [] + + async def get_invoice_months(self): + """Get all invoice months in database + Aggregated views contain invoice_month field + """ + + _query = f""" + SELECT DISTINCT invoice_month + FROM `{self.table_name}` + ORDER BY invoice_month DESC; + """ + + query_job_result = self._execute_query(_query) + if query_job_result: + return [str(dict(row)['invoice_month']) for row in query_job_result] + + # return empty list if no record found + return [] + + async def get_cost_categories(self): + """Get all service description in database""" + + # cost of this BQ is 10MB on DEV is minimal, AU$ 0.000008 per query + # @days is defined by env variable BQ_DAYS_BACK_OPTIMAL + # this day > filter is to limit the amount of data scanned, + # saving cost for running BQ + # aggregated views are partitioned by day + _query = f""" + SELECT DISTINCT cost_category + FROM `{BQ_AGGREG_VIEW}` + WHERE day > TIMESTAMP_ADD( + CURRENT_TIMESTAMP(), INTERVAL @days DAY + ) + ORDER BY cost_category ASC; + """ + + query_parameters = [ + time_optimisation_parameter(), + ] + query_job_result = self._execute_query(_query, query_parameters) + + if query_job_result: + return [str(dict(row)['cost_category']) for row in query_job_result] + + # return empty list if no record found + return [] + + async def get_skus( + self, + limit: int | None = None, + offset: int | None = None, + ): + """Get all SKUs in database""" + + # cost of this BQ is 10MB on DEV is minimal, AU$ 0.000008 per query + # @days is defined by env variable BQ_DAYS_BACK_OPTIMAL + # this day > filter is to limit the amount of data scanned, + # saving cost for running BQ + # aggregated views are partitioned by day + _query = f""" + SELECT DISTINCT sku + FROM `{self.table_name}` + WHERE day > TIMESTAMP_ADD( + CURRENT_TIMESTAMP(), INTERVAL @days DAY + ) + ORDER BY sku ASC + """ + + # append LIMIT and OFFSET if present + if limit: + _query += ' LIMIT @limit_val' + if offset: + _query += ' OFFSET @offset_val' + + query_parameters = [ + time_optimisation_parameter(), + bigquery.ScalarQueryParameter('limit_val', 'INT64', limit), + bigquery.ScalarQueryParameter('offset_val', 'INT64', offset), + ] + query_job_result = self._execute_query(_query, query_parameters) + + if query_job_result: + return [str(dict(row)['sku']) for row in query_job_result] + + # return empty list if no record found + return [] diff --git a/db/python/tables/bq/billing_daily_extended.py b/db/python/tables/bq/billing_daily_extended.py new file mode 100644 index 000000000..009144911 --- /dev/null +++ b/db/python/tables/bq/billing_daily_extended.py @@ -0,0 +1,51 @@ +from api.settings import BQ_AGGREG_EXT_VIEW +from db.python.tables.bq.billing_base import ( + BillingBaseTable, + time_optimisation_parameter, +) +from models.models import BillingColumn + + +class BillingDailyExtendedTable(BillingBaseTable): + """Billing Aggregated Daily Extended Biq Query table""" + + table_name = BQ_AGGREG_EXT_VIEW + + def get_table_name(self): + """Get table name""" + return self.table_name + + async def get_extended_values(self, field: str): + """ + Get all extended values in database, for specified field. + Field is one of extended columns. 
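# Illustrative sketch (not part of the changeset above): the running-cost
# accumulation pattern used by get_running_cost() in billing_base.py -- one
# Counter per cost group ('C'ompute / 'S'torage), keyed by field value plus an
# 'ALL' bucket. The rows below are made-up sample data, not real billing records.
from collections import Counter, defaultdict

rows = [
    {'field': 'topic-a', 'cost_category': 'Cloud Storage', 'monthly_cost': 12.0},
    {'field': 'topic-a', 'cost_category': 'Compute Engine', 'monthly_cost': 30.0},
    {'field': 'topic-b', 'cost_category': 'Compute Engine', 'monthly_cost': 5.0},
]

total_monthly: dict[str, Counter] = defaultdict(Counter)
for row in rows:
    cost_group = 'S' if row['cost_category'] == 'Cloud Storage' else 'C'
    total_monthly[cost_group]['ALL'] += row['monthly_cost']      # cost-group total
    total_monthly[cost_group][row['field']] += row['monthly_cost']  # per-field total

assert total_monthly['C']['ALL'] == 35.0
assert total_monthly['S']['topic-a'] == 12.0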
+ """ + + if field not in BillingColumn.extended_cols(): + raise ValueError('Invalid field value') + + # cost of this BQ is 10MB on DEV is minimal, AU$ 0.000008 per query + # @days is defined by env variable BQ_DAYS_BACK_OPTIMAL + # this day > filter is to limit the amount of data scanned, + # saving cost for running BQ + # aggregated views are partitioned by day + _query = f""" + SELECT DISTINCT {field} + FROM `{self.table_name}` + WHERE {field} IS NOT NULL + AND day > TIMESTAMP_ADD( + CURRENT_TIMESTAMP(), INTERVAL @days DAY + ) + ORDER BY 1 ASC; + """ + + query_parameters = [ + time_optimisation_parameter(), + ] + query_job_result = self._execute_query(_query, query_parameters) + + if query_job_result: + return [str(dict(row)[field]) for row in query_job_result] + + # return empty list if no record found + return [] diff --git a/db/python/tables/bq/billing_filter.py b/db/python/tables/bq/billing_filter.py new file mode 100644 index 000000000..9a379817f --- /dev/null +++ b/db/python/tables/bq/billing_filter.py @@ -0,0 +1,48 @@ +# pylint: disable=unused-import,too-many-instance-attributes + +import dataclasses +import datetime + +from db.python.tables.bq.generic_bq_filter import GenericBQFilter +from db.python.tables.bq.generic_bq_filter_model import GenericBQFilterModel + + +@dataclasses.dataclass +class BillingFilter(GenericBQFilterModel): + """ + Filter for billing, contains all possible attributes to filter on + """ + + # partition specific filters: + + # most billing views are parttioned by day + day: GenericBQFilter[datetime.datetime] = None + + # gpc table has different partition field: part_time + part_time: GenericBQFilter[datetime.datetime] = None + + # aggregate has different partition field: usage_end_time + usage_end_time: GenericBQFilter[datetime.datetime] = None + + # common filters: + invoice_month: GenericBQFilter[str] = None + + # min cost e.g. 
0.01, if not set, will show all + cost: GenericBQFilter[float] = None + + ar_guid: GenericBQFilter[str] = None + gcp_project: GenericBQFilter[str] = None + topic: GenericBQFilter[str] = None + batch_id: GenericBQFilter[str] = None + cost_category: GenericBQFilter[str] = None + sku: GenericBQFilter[str] = None + dataset: GenericBQFilter[str] = None + sequencing_type: GenericBQFilter[str] = None + stage: GenericBQFilter[str] = None + sequencing_group: GenericBQFilter[str] = None + compute_category: GenericBQFilter[str] = None + cromwell_sub_workflow_name: GenericBQFilter[str] = None + cromwell_workflow_id: GenericBQFilter[str] = None + goog_pipelines_worker: GenericBQFilter[str] = None + wdl_task_name: GenericBQFilter[str] = None + namespace: GenericBQFilter[str] = None diff --git a/db/python/tables/bq/billing_gcp_daily.py b/db/python/tables/bq/billing_gcp_daily.py new file mode 100644 index 000000000..b765547c3 --- /dev/null +++ b/db/python/tables/bq/billing_gcp_daily.py @@ -0,0 +1,144 @@ +from datetime import datetime, timedelta + +from google.cloud import bigquery + +from api.settings import BQ_GCP_BILLING_VIEW +from db.python.tables.bq.billing_base import ( + BillingBaseTable, + time_optimisation_parameter, +) +from db.python.tables.bq.billing_filter import BillingFilter +from db.python.tables.bq.generic_bq_filter import GenericBQFilter +from models.models import BillingTotalCostQueryModel + + +class BillingGcpDailyTable(BillingBaseTable): + """Billing GCP Daily Big Query table""" + + table_name = BQ_GCP_BILLING_VIEW + + def get_table_name(self): + """Get table name""" + return self.table_name + + def _query_to_partitioned_filter( + self, query: BillingTotalCostQueryModel + ) -> BillingFilter: + """ + add extra filter to limit materialized view partition + Raw BQ billing table is partitioned by part_time (when data are loaded) + and not by end of usage time (day) + There is a delay up to 4-5 days between part_time and day + 7 days is added to be sure to get all data + """ + billing_filter = query.to_filter() + + # initial partition filter + billing_filter.part_time = GenericBQFilter[datetime]( + gte=datetime.strptime(query.start_date, '%Y-%m-%d') + if query.start_date + else None, + lte=(datetime.strptime(query.end_date, '%Y-%m-%d') + timedelta(days=7)) + if query.end_date + else None, + ) + # add day filter after partition filter is applied + billing_filter.day = GenericBQFilter[datetime]( + gte=datetime.strptime(query.start_date, '%Y-%m-%d') + if query.start_date + else None, + lte=datetime.strptime(query.end_date, '%Y-%m-%d') + if query.end_date + else None, + ) + return billing_filter + + async def _last_loaded_day(self): + """Get the most recent fully loaded day in db + Go 2 days back as the data is not always available for the current day + 1 day back is not enough + """ + + _query = f""" + SELECT TIMESTAMP_ADD(MAX(part_time), INTERVAL -2 DAY) as last_loaded_day + FROM `{self.table_name}` + WHERE part_time > TIMESTAMP_ADD( + CURRENT_TIMESTAMP(), INTERVAL @days DAY + ) + """ + + query_parameters = [ + time_optimisation_parameter(), + ] + query_job_result = self._execute_query(_query, query_parameters) + + if query_job_result: + return str(query_job_result[0].last_loaded_day) + + return None + + def _prepare_daily_cost_subquery(self, field, query_params, last_loaded_day): + """prepare daily cost subquery""" + + # add extra filter to limit materialized view partition + # Raw BQ billing table is partitioned by part_time (when data are loaded) + # and not by end of usage time (day) + # 
There is a delay of up to 4-5 days between part_time and day + # 7 days is added to be sure to get all data + gcp_billing_optimise_filter = """ + AND part_time >= TIMESTAMP(@last_loaded_day) + AND part_time <= TIMESTAMP_ADD( + TIMESTAMP(@last_loaded_day), INTERVAL 7 DAY + ) + """ + + daily_cost_field = ', day.cost as daily_cost' + daily_cost_join = f"""LEFT JOIN ( + SELECT + {field.value} as field, + cost_category, + SUM(cost) as cost + FROM + `{self.get_table_name()}` + WHERE day = TIMESTAMP(@last_loaded_day) + {gcp_billing_optimise_filter} + GROUP BY + field, + cost_category + ) day + ON month.field = day.field + AND month.cost_category = day.cost_category + """ + + query_params.append( + bigquery.ScalarQueryParameter('last_loaded_day', 'STRING', last_loaded_day), + ) + return (query_params, daily_cost_field, daily_cost_join) + + async def get_gcp_projects(self): + """Get all GCP projects in database""" + + # the cost of this BQ query is minimal: 10MB on DEV, about AU$ 0.000008 per query + # @days is defined by env variable BQ_DAYS_BACK_OPTIMAL + # this part_time > filter is to limit the amount of data scanned, + # saving cost for running BQ + _query = f""" + SELECT DISTINCT gcp_project + FROM `{self.table_name}` + WHERE part_time > TIMESTAMP_ADD( + CURRENT_TIMESTAMP(), INTERVAL @days DAY + ) + AND gcp_project IS NOT NULL + ORDER BY gcp_project ASC; + """ + + query_parameters = [ + time_optimisation_parameter(), + ] + query_job_result = self._execute_query(_query, query_parameters) + + if query_job_result: + return [str(dict(row)['gcp_project']) for row in query_job_result] + + # return empty list if no record found + return [] diff --git a/db/python/tables/bq/billing_raw.py b/db/python/tables/bq/billing_raw.py new file mode 100644 index 000000000..a82fa4eec --- /dev/null +++ b/db/python/tables/bq/billing_raw.py @@ -0,0 +1,36 @@ +from datetime import datetime + +from api.settings import BQ_AGGREG_RAW +from db.python.tables.bq.billing_base import BillingBaseTable +from db.python.tables.bq.billing_filter import BillingFilter +from db.python.tables.bq.generic_bq_filter import GenericBQFilter +from models.models import BillingTotalCostQueryModel + + +class BillingRawTable(BillingBaseTable): + """Billing Raw (Consolidated) Big Query table""" + + table_name = BQ_AGGREG_RAW + + def get_table_name(self): + """Get table name""" + return self.table_name + + def _query_to_partitioned_filter( + self, query: BillingTotalCostQueryModel + ) -> BillingFilter: + """ + Raw BQ billing table is partitioned by usage_end_time + """ + billing_filter = query.to_filter() + + # initial partition filter + billing_filter.usage_end_time = GenericBQFilter[datetime]( + gte=datetime.strptime(query.start_date, '%Y-%m-%d') + if query.start_date + else None, + lte=datetime.strptime(query.end_date, '%Y-%m-%d') + if query.end_date + else None, + ) + return billing_filter diff --git a/db/python/tables/bq/function_bq_filter.py b/db/python/tables/bq/function_bq_filter.py new file mode 100644 index 000000000..f18f60211 --- /dev/null +++ b/db/python/tables/bq/function_bq_filter.py @@ -0,0 +1,109 @@ +from datetime import datetime +from enum import Enum +from typing import Any + +from google.cloud import bigquery + +from models.models import BillingColumn + + +class FunctionBQFilter: + """ + Function BigQuery filter where the left side is a function call. + In such cases we need to parameterise values on both sides of the SQL + E.g. + + SELECT ... + FROM ...
+ WHERE getLabelValue(labels, 'batch_id') = '1234' + + In this case we have 2 string values which need to be parameterised + """ + + func_where = '' + func_sql_parameters: list[ + bigquery.ScalarQueryParameter | bigquery.ArrayQueryParameter + ] = [] + + def __init__(self, name: str, implementation: str): + self.func_name = name + self.fun_implementation = implementation + # param_id is a counter for parameterised values + self._param_id = 0 + + def to_sql( + self, + column_name: BillingColumn, + func_params: str | list[Any] | dict[Any, Any], + func_operator: str = None, + ) -> tuple[str, list[bigquery.ScalarQueryParameter | bigquery.ArrayQueryParameter]]: + """ + creates the left side of where : FUN(column_name, @params) + each of func_params convert to BQ parameter + combined multiple calls with provided operator, + if func_operator is None then AND is assumed by default + """ + values = [] + conditionals = [] + + if not isinstance(func_params, dict): + # Ignore func_params which are not dictionary for the time being + return '', [] + + for param_key, param_value in func_params.items(): + # parameterised both param_key and param_value + # e.g. this is raw SQL example: + # getLabelValue(labels, {param_key}) = {param_value} + self._param_id += 1 + key = f'param{self._param_id}' + val = f'value{self._param_id}' + # add param_key as parameterised BQ value + values.append(FunctionBQFilter._sql_value_prep(key, param_key)) + + # add param_value as parameterised BQ value + values.append(FunctionBQFilter._sql_value_prep(val, param_value)) + + # format as FUN(column_name, @param) = @value + conditionals.append( + ( + f'{self.func_name}({column_name.value},@{key}) = ' + f'{FunctionBQFilter._sql_cond_prep(val, param_value)}' + ) + ) + + if func_operator and func_operator == 'OR': + condition = ' OR '.join(conditionals) + else: + condition = ' AND '.join(conditionals) + + # set the class variables for later use + self.func_where = f'({condition})' + self.func_sql_parameters = values + return self.func_where, self.func_sql_parameters + + @staticmethod + def _sql_cond_prep(key: str, value: Any) -> str: + """ + By default '{key}' is used, + but for datetime it has to be wrapped in TIMESTAMP({key}) + """ + if isinstance(value, datetime): + return f'TIMESTAMP(@{key})' + + # otherwise as default + return f'@{key}' + + @staticmethod + def _sql_value_prep(key: str, value: Any) -> bigquery.ScalarQueryParameter: + """ """ + if isinstance(value, Enum): + return FunctionBQFilter._sql_value_prep(key, value.value) + if isinstance(value, int): + return bigquery.ScalarQueryParameter(key, 'INT64', value) + if isinstance(value, float): + return bigquery.ScalarQueryParameter(key, 'FLOAT64', value) + if isinstance(value, datetime): + return bigquery.ScalarQueryParameter(key, 'STRING', value) + + # otherwise as string parameter + return bigquery.ScalarQueryParameter(key, 'STRING', value) diff --git a/db/python/tables/bq/generic_bq_filter.py b/db/python/tables/bq/generic_bq_filter.py new file mode 100644 index 000000000..b0bfba973 --- /dev/null +++ b/db/python/tables/bq/generic_bq_filter.py @@ -0,0 +1,101 @@ +from datetime import datetime +from enum import Enum +from typing import Any + +from google.cloud import bigquery + +from db.python.utils import GenericFilter, T + + +class GenericBQFilter(GenericFilter[T]): + """ + Generic BigQuery filter is BQ specific filter class, based on GenericFilter + """ + + def to_sql( + self, column: str, column_name: str = None + ) -> tuple[str, dict[str, T | list[T] | Any | list[Any]]]: + 
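# Illustrative sketch (not part of the changeset above): what
# FunctionBQFilter.to_sql() above is expected to produce for a labels filter.
# The label keys/values are invented, the rendered clause assumes
# BillingColumn.LABELS.value == 'labels', and the implementation string is
# shortened here for brevity.
from db.python.tables.bq.function_bq_filter import FunctionBQFilter
from models.models import BillingColumn

func_filter = FunctionBQFilter(
    name='getLabelValue',
    implementation='CREATE TEMP FUNCTION getLabelValue(...) AS (...);',
)
where, params = func_filter.to_sql(
    BillingColumn.LABELS, {'batch_id': '1234', 'ar-guid': 'abc'}
)
# where  -> '(getLabelValue(labels,@param1) = @value1
#             AND getLabelValue(labels,@param2) = @value2)'
# params -> four bigquery.ScalarQueryParameter objects
#           (param1/value1 and param2/value2, all STRING)
print(where, len(params))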
""" + Convert to SQL, and avoid SQL injection + + """ + conditionals = [] + values: dict[str, T | list[T] | Any | list[Any]] = {} + _column_name = column_name or column + + if not isinstance(column, str): + raise ValueError(f'Column {_column_name!r} must be a string') + if self.eq is not None: + k = self.generate_field_name(_column_name + '_eq') + conditionals.append(f'{column} = {self._sql_cond_prep(k, self.eq)}') + values[k] = self._sql_value_prep(k, self.eq) + if self.in_ is not None: + if not isinstance(self.in_, list): + raise ValueError('IN filter must be a list') + if len(self.in_) == 1: + k = self.generate_field_name(_column_name + '_in_eq') + conditionals.append(f'{column} = {self._sql_cond_prep(k, self.in_[0])}') + values[k] = self._sql_value_prep(k, self.in_[0]) + else: + k = self.generate_field_name(_column_name + '_in') + conditionals.append(f'{column} IN ({self._sql_cond_prep(k, self.in_)})') + values[k] = self._sql_value_prep(k, self.in_) + if self.nin is not None: + if not isinstance(self.nin, list): + raise ValueError('NIN filter must be a list') + k = self.generate_field_name(column + '_nin') + conditionals.append(f'{column} NOT IN ({self._sql_cond_prep(k, self.nin)})') + values[k] = self._sql_value_prep(k, self.nin) + if self.gt is not None: + k = self.generate_field_name(column + '_gt') + conditionals.append(f'{column} > {self._sql_cond_prep(k, self.gt)}') + values[k] = self._sql_value_prep(k, self.gt) + if self.gte is not None: + k = self.generate_field_name(column + '_gte') + conditionals.append(f'{column} >= {self._sql_cond_prep(k, self.gte)}') + values[k] = self._sql_value_prep(k, self.gte) + if self.lt is not None: + k = self.generate_field_name(column + '_lt') + conditionals.append(f'{column} < {self._sql_cond_prep(k, self.lt)}') + values[k] = self._sql_value_prep(k, self.lt) + if self.lte is not None: + k = self.generate_field_name(column + '_lte') + conditionals.append(f'{column} <= {self._sql_cond_prep(k, self.lte)}') + values[k] = self._sql_value_prep(k, self.lte) + + return ' AND '.join(conditionals), values + + @staticmethod + def _sql_cond_prep(key, value) -> str: + """ + By default '@{key}' is used, + but for datetime it has to be wrapped in TIMESTAMP(@{k}) + """ + if isinstance(value, datetime): + return f'TIMESTAMP(@{key})' + + # otherwise as default + return f'@{key}' + + @staticmethod + def _sql_value_prep(key, value): + """ + Overrides the default _sql_value_prep to handle BQ parameters + """ + if isinstance(value, list): + return bigquery.ArrayQueryParameter( + key, 'STRING', ','.join([str(v) for v in value]) + ) + if isinstance(value, Enum): + return GenericBQFilter._sql_value_prep(key, value.value) + if isinstance(value, int): + return bigquery.ScalarQueryParameter(key, 'INT64', value) + if isinstance(value, float): + return bigquery.ScalarQueryParameter(key, 'FLOAT64', value) + if isinstance(value, datetime): + return bigquery.ScalarQueryParameter( + key, 'STRING', value.strftime('%Y-%m-%d %H:%M:%S') + ) + + # otherwise as string parameter + return bigquery.ScalarQueryParameter(key, 'STRING', value) diff --git a/db/python/tables/bq/generic_bq_filter_model.py b/db/python/tables/bq/generic_bq_filter_model.py new file mode 100644 index 000000000..c2736cc3a --- /dev/null +++ b/db/python/tables/bq/generic_bq_filter_model.py @@ -0,0 +1,111 @@ +import dataclasses +from typing import Any + +from db.python.tables.bq.generic_bq_filter import GenericBQFilter +from db.python.utils import GenericFilterModel + + +def prepare_bq_query_from_dict_field( + filter_, 
field_name, column_name +) -> tuple[list[str], dict[str, Any]]: + """ + Prepare a SQL query from a dict field, which is a dict of GenericFilters. + Usually this is a JSON field in the database that we want to query on. + """ + conditionals: list[str] = [] + values: dict[str, Any] = {} + for key, value in filter_.items(): + if not isinstance(value, GenericBQFilter): + raise ValueError(f'Filter {field_name} must be a GenericFilter') + if '"' in key: + raise ValueError('Meta key contains " character, which is not allowed') + if "'" in key: + raise ValueError("Meta key contains ' character, which is not allowed") + fconditionals, fvalues = value.to_sql( + f"JSON_EXTRACT({column_name}, '$.{key}')", + column_name=f'{column_name}_{key}', + ) + conditionals.append(fconditionals) + values.update(fvalues) + + return conditionals, values + + +@dataclasses.dataclass(kw_only=True) +class GenericBQFilterModel(GenericFilterModel): + """ + Class that contains fields of GenericBQFilters that can be used to filter + """ + + def __post_init__(self): + for field in dataclasses.fields(self): + value = getattr(self, field.name) + if value is None: + continue + + if isinstance(value, tuple) and len(value) == 1 and value[0] is None: + raise ValueError( + 'There is very likely a trailing comma on the end of ' + f'{self.__class__.__name__}.{field.name}. If you actually want a ' + 'tuple of length one with the value = (None,), then use ' + 'dataclasses.field(default_factory=lambda: (None,))' + ) + if isinstance(value, GenericBQFilter): + continue + + if isinstance(value, dict): + # make sure each field is a GenericFilter, or set it to be one, + # in this case it's always 'eq', never automatically in_ + new_value = { + k: v if isinstance(v, GenericBQFilter) else GenericBQFilter(eq=v) + for k, v in value.items() + } + setattr(self, field.name, new_value) + continue + + # lazily provided a value, which we'll correct + if isinstance(value, list): + setattr(self, field.name, GenericBQFilter(in_=value)) + else: + setattr(self, field.name, GenericBQFilter(eq=value)) + + def to_sql( + self, field_overrides: dict[str, Any] = None + ) -> tuple[str, dict[str, Any]]: + """Convert the model to SQL, and avoid SQL injection""" + _foverrides = field_overrides or {} + + # check for bad field_overrides + bad_field_overrides = set(_foverrides.keys()) - set( + f.name for f in dataclasses.fields(self) + ) + if bad_field_overrides: + raise ValueError( + f'Specified field overrides that were not used: {bad_field_overrides}' + ) + + fields = dataclasses.fields(self) + conditionals, values = [], {} + for field in fields: + fcolumn = _foverrides.get(field.name, field.name) + if filter_ := getattr(self, field.name): + if isinstance(filter_, dict): + meta_conditionals, meta_values = prepare_bq_query_from_dict_field( + filter_=filter_, field_name=field.name, column_name=fcolumn + ) + conditionals.extend(meta_conditionals) + values.update(meta_values) + elif isinstance(filter_, GenericBQFilter): + fconditionals, fvalues = filter_.to_sql(fcolumn) + conditionals.append(fconditionals) + values.update(fvalues) + else: + raise ValueError( + f'Filter {field.name} must be a GenericBQFilter or ' + 'dict[str, GenericBQFilter]' + ) + + if not conditionals: + return 'True', {} + + return ' AND '.join(filter(None, conditionals)), values diff --git a/db/python/tables/family.py b/db/python/tables/family.py index 853523b6f..18b13dd72 100644 --- a/db/python/tables/family.py +++ b/db/python/tables/family.py @@ -1,9 +1,10 @@ from collections import defaultdict 
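# Illustrative sketch (not part of the changeset above): how a BillingFilter
# (billing_filter.py) flattens into a parameterised WHERE clause via
# GenericBQFilterModel.to_sql(). The dates and topic are example values; the
# exact parameter names come from GenericFilter.generate_field_name(), so the
# rendered clause shown below is only approximate.
from datetime import datetime

from db.python.tables.bq.billing_filter import BillingFilter
from db.python.tables.bq.generic_bq_filter import GenericBQFilter

bf = BillingFilter(
    day=GenericBQFilter[datetime](
        gte=datetime(2023, 1, 1),
        lte=datetime(2023, 1, 31),
    ),
    topic=GenericBQFilter[str](eq='hail'),
)
where_str, params = bf.to_sql()
# where_str -> roughly:
#   'day >= TIMESTAMP(@day_gte) AND day <= TIMESTAMP(@day_lte) AND topic = @topic_eq'
# params    -> dict of bigquery.ScalarQueryParameter objects keyed by those names
print(where_str, list(params))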
-from typing import List, Optional, Set, Any, Dict +from typing import Any, Dict, List, Optional, Set -from db.python.connect import DbBase, NotFoundError -from db.python.tables.project import ProjectId +from db.python.tables.base import DbBase +from db.python.utils import NotFoundError from models.models.family import FamilyInternal +from models.models.project import ProjectId class FamilyTable(DbBase): @@ -47,7 +48,7 @@ async def get_families( JOIN family_participant ON family.id = family_participant.family_id """ - where.append(f'participant_id IN :pids') + where.append('participant_id IN :pids') values['pids'] = participant_ids if project or self.project: @@ -175,13 +176,12 @@ async def get_family_external_ids_by_participant_ids( async def update_family( self, id_: int, - external_id: str = None, - description: str = None, - coded_phenotype: str = None, - author: str = None, + external_id: str | None = None, + description: str | None = None, + coded_phenotype: str | None = None, ) -> bool: """Update values for a family""" - values: Dict[str, Any] = {'author': author or self.author} + values: Dict[str, Any] = {'audit_log_id': await self.audit_log_id()} if external_id: values['external_id'] = external_id if description: @@ -203,7 +203,6 @@ async def create_family( external_id: str, description: Optional[str], coded_phenotype: Optional[str], - author: str = None, project: ProjectId = None, ) -> int: """ @@ -213,7 +212,7 @@ async def create_family( 'external_id': external_id, 'description': description, 'coded_phenotype': coded_phenotype, - 'author': author or self.author, + 'audit_log_id': await self.audit_log_id(), 'project': project or self.project, } keys = list(updater.keys()) @@ -235,7 +234,6 @@ async def insert_or_update_multiple_families( descriptions: List[str], coded_phenotypes: List[Optional[str]], project: int = None, - author: str = None, ): """Upsert""" updater = [ @@ -243,7 +241,7 @@ async def insert_or_update_multiple_families( 'external_id': eid, 'description': descr, 'coded_phenotype': cph, - 'author': author or self.author, + 'audit_log_id': await self.audit_log_id(), 'project': project or self.project, } for eid, descr, cph in zip(external_ids, descriptions, coded_phenotypes) diff --git a/db/python/tables/family_participant.py b/db/python/tables/family_participant.py index 8f2c730d6..8aab115a8 100644 --- a/db/python/tables/family_participant.py +++ b/db/python/tables/family_participant.py @@ -1,9 +1,9 @@ from collections import defaultdict -from typing import Tuple, List, Dict, Optional, Set, Any +from typing import Any, Dict, List, Optional, Set, Tuple -from db.python.connect import DbBase -from db.python.tables.project import ProjectId +from db.python.tables.base import DbBase from models.models.family import PedRowInternal +from models.models.project import ProjectId class FamilyParticipantTable(DbBase): @@ -21,7 +21,6 @@ async def create_row( maternal_id: int, affected: int, notes: str = None, - author=None, ) -> Tuple[int, int]: """ Create a new sample, and add it to database @@ -33,7 +32,7 @@ async def create_row( 'maternal_participant_id': maternal_id, 'affected': affected, 'notes': notes, - 'author': author or self.author, + 'audit_log_id': await self.audit_log_id(), } keys = list(updater.keys()) str_keys = ', '.join(keys) @@ -52,7 +51,6 @@ async def create_row( async def create_rows( self, rows: list[PedRowInternal], - author=None, ): """ Create many rows, dictionaries must have keys: @@ -76,7 +74,7 @@ async def create_rows( 'maternal_participant_id': 
row.maternal_id, 'affected': row.affected, 'notes': row.notes, - 'author': author or self.author, + 'audit_log_id': await self.audit_log_id(), } remapped_ds_by_keys[tuple(sorted(d.keys()))].append(d) @@ -225,12 +223,27 @@ async def delete_family_participant_row(self, family_id: int, participant_id: in if not participant_id or not family_id: return False + _update_before_delete = """ + UPDATE family_participant + SET audit_log_id = :audit_log_id + WHERE family_id = :family_id AND participant_id = :participant_id + """ + _query = """ -DELETE FROM family_participant -WHERE participant_id = :participant_id -AND family_id = :family_id + DELETE FROM family_participant + WHERE participant_id = :participant_id + AND family_id = :family_id """ + await self.connection.execute( + _update_before_delete, + { + 'family_id': family_id, + 'participant_id': participant_id, + 'audit_log_id': await self.audit_log_id(), + }, + ) + await self.connection.execute( _query, {'family_id': family_id, 'participant_id': participant_id} ) diff --git a/db/python/tables/participant.py b/db/python/tables/participant.py index 0b0b8eb9c..6f326f9cf 100644 --- a/db/python/tables/participant.py +++ b/db/python/tables/participant.py @@ -1,9 +1,10 @@ from collections import defaultdict from typing import Any -from db.python.connect import DbBase, NotFoundError -from db.python.utils import ProjectId, to_db_json +from db.python.tables.base import DbBase +from db.python.utils import NotFoundError, to_db_json from models.models.participant import ParticipantInternal +from models.models.project import ProjectId class ParticipantTable(DbBase): @@ -20,6 +21,7 @@ class ParticipantTable(DbBase): 'karyotype', 'meta', 'project', + 'audit_log_id', ] ) @@ -75,7 +77,6 @@ async def create_participant( reported_gender: str | None, karyotype: str | None, meta: dict | None, - author: str = None, project: ProjectId = None, ) -> int: """ @@ -84,9 +85,11 @@ async def create_participant( if not (project or self.project): raise ValueError('Must provide project to create participant') - _query = f""" -INSERT INTO participant (external_id, reported_sex, reported_gender, karyotype, meta, author, project) -VALUES (:external_id, :reported_sex, :reported_gender, :karyotype, :meta, :author, :project) + _query = """ +INSERT INTO participant + (external_id, reported_sex, reported_gender, karyotype, meta, audit_log_id, project) +VALUES + (:external_id, :reported_sex, :reported_gender, :karyotype, :meta, :audit_log_id, :project) RETURNING id """ @@ -98,7 +101,7 @@ async def create_participant( 'reported_gender': reported_gender, 'karyotype': karyotype, 'meta': to_db_json(meta or {}), - 'author': author or self.author, + 'audit_log_id': await self.audit_log_id(), 'project': project or self.project, }, ) @@ -110,18 +113,17 @@ async def update_participants( reported_genders: list[str] | None, karyotypes: list[str] | None, metas: list[dict] | None, - author=None, ): """ Update many participants, expects that all lists contain the same number of values. You can't update selective fields on selective samples, if you provide metas, this function will update EVERY participant with the provided meta values. 
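# Illustrative sketch (not part of the changeset above): the
# "stamp audit_log_id, then delete" pattern used by
# delete_family_participant_row() above, presumably so the row's history
# records which change removed it. `connection` is a databases.Database-style
# connection and `audit_log_id` comes from the surrounding layer; both are
# assumed here.
async def delete_family_participant_with_audit(
    connection, audit_log_id: int, family_id: int, participant_id: int
):
    values = {'family_id': family_id, 'participant_id': participant_id}
    # record which audit log entry is about to delete these rows
    await connection.execute(
        'UPDATE family_participant SET audit_log_id = :audit_log_id '
        'WHERE family_id = :family_id AND participant_id = :participant_id',
        {**values, 'audit_log_id': audit_log_id},
    )
    # then perform the actual delete
    await connection.execute(
        'DELETE FROM family_participant '
        'WHERE family_id = :family_id AND participant_id = :participant_id',
        values,
    )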
""" - _author = author or self.author - updaters = ['author = :author'] + updaters = ['audit_log_id = :audit_log_id'] + audit_log_id = await self.audit_log_id() values: dict[str, list[Any]] = { 'pid': participant_ids, - 'author': [_author] * len(participant_ids), + 'audit_log_id': [audit_log_id] * len(participant_ids), } if reported_sexes: updaters.append('reported_sex = :reported_sex') @@ -154,13 +156,12 @@ async def update_participant( reported_gender: str | None, karyotype: str | None, meta: dict | None, - author=None, ): """ Update participant """ - updaters = ['author = :author'] - fields = {'pid': participant_id, 'author': author or self.author} + updaters = ['audit_log_id = :audit_log_id'] + fields = {'pid': participant_id, 'audit_log_id': await self.audit_log_id()} if external_id: updaters.append('external_id = :external_id') @@ -248,7 +249,7 @@ async def get_participants_by_families( self, family_ids: list[int] ) -> tuple[set[ProjectId], dict[int, list[ParticipantInternal]]]: """Get list of participants keyed by families, duplicates results""" - _query = f""" + _query = """ SELECT project, fp.family_id, p.id, p.external_id, p.reported_sex, p.reported_gender, p.karyotype, p.meta FROM participant p INNER JOIN family_participant fp ON fp.participant_id = p.id @@ -271,10 +272,11 @@ async def update_many_participant_external_ids( """Update many participant external_ids through the {internal: external} map""" _query = """ UPDATE participant - SET external_id = :external_id + SET external_id = :external_id, audit_log_id = :audit_log_id WHERE id = :participant_id""" + audit_log_id = await self.audit_log_id() mapped_values = [ - {'participant_id': k, 'external_id': v} + {'participant_id': k, 'external_id': v, 'audit_log_id': audit_log_id} for k, v in internal_to_external_id.items() ] await self.connection.execute_many(_query, mapped_values) diff --git a/db/python/tables/participant_phenotype.py b/db/python/tables/participant_phenotype.py index 556617209..ea010588d 100644 --- a/db/python/tables/participant_phenotype.py +++ b/db/python/tables/participant_phenotype.py @@ -1,9 +1,8 @@ -from typing import List, Tuple, Any, Dict - import json from collections import defaultdict +from typing import Any, Dict, List, Tuple -from db.python.connect import DbBase +from db.python.tables.base import DbBase class ParticipantPhenotypeTable(DbBase): @@ -13,27 +12,28 @@ class ParticipantPhenotypeTable(DbBase): table_name = 'participant_phenotype' - async def add_key_value_rows( - self, rows: List[Tuple[int, str, Any]], author: str = None - ) -> None: + async def add_key_value_rows(self, rows: List[Tuple[int, str, Any]]) -> None: """ Create a new sample, and add it to database """ if not rows: return None - _query = f""" -INSERT INTO participant_phenotypes (participant_id, description, value, author, hpo_term) -VALUES (:participant_id, :description, :value, :author, 'DESCRIPTION') + _query = """ +INSERT INTO participant_phenotypes + (participant_id, description, value, audit_log_id, hpo_term) +VALUES + (:participant_id, :description, :value, :audit_log_id, 'DESCRIPTION') ON DUPLICATE KEY UPDATE - description=:description, value=:value, author=:author + description=:description, value=:value, audit_log_id=:audit_log_id """ + audit_log_id = await self.audit_log_id() formatted_rows = [ { 'participant_id': r[0], 'description': r[1], 'value': json.dumps(r[2]), - 'author': author or self.author, + 'audit_log_id': audit_log_id, } for r in rows ] diff --git a/db/python/tables/project.py b/db/python/tables/project.py 
index 2bcd60c9f..d87c28960 100644 --- a/db/python/tables/project.py +++ b/db/python/tables/project.py @@ -5,16 +5,16 @@ from databases import Database from api.settings import is_all_access +from db.python.connect import Connection, SMConnections from db.python.utils import ( Forbidden, InternalError, NoProjectAccess, NotFoundError, - ProjectId, get_logger, to_db_json, ) -from models.models.project import Project +from models.models.project import Project, ProjectId logger = get_logger() @@ -38,14 +38,61 @@ def get_project_group_name(project_name: str, readonly: bool) -> str: return f'{project_name}-read' return f'{project_name}-write' - def __init__(self, connection: Database, allow_full_access=None): - if not isinstance(connection, Database): + def __init__( + self, + connection: Connection | None, + allow_full_access: bool | None = None, + database_connection: Database | None = None, + ): + self._connection = connection + if not database_connection and not connection: raise ValueError( - f'Invalid type connection, expected Database, got {type(connection)}, ' - 'did you forget to call connection.connection?' + 'Must call project permissions table with either a direct ' + 'database_connection or a fully formed connection' ) - self.connection: Database = connection - self.gtable = GroupTable(connection, allow_full_access=allow_full_access) + self.connection: Database = database_connection or connection.connection + self.gtable = GroupTable(self.connection, allow_full_access=allow_full_access) + + @staticmethod + async def get_project_connection( + *, + author: str, + project_name: str, + readonly: bool, + ar_guid: str, + on_behalf_of: str | None = None, + meta: dict[str, str] | None = None, + ): + """Get a db connection from a project and user""" + # maybe it makes sense to perform permission checks here too + logger.debug(f'Authenticate connection to {project_name} with {author!r}') + + conn = await SMConnections.get_made_connection() + pt = ProjectPermissionsTable(connection=None, database_connection=conn) + + project = await pt.get_and_check_access_to_project_for_name( + user=author, project_name=project_name, readonly=readonly + ) + + return Connection( + connection=conn, + author=author, + project=project.id, + readonly=readonly, + on_behalf_of=on_behalf_of, + ar_guid=ar_guid, + meta=meta, + ) + + async def audit_log_id(self): + """ + Generate (or return) a audit_log_id by inserting a row into the database + """ + if not self._connection: + raise ValueError( + 'Cannot call audit_log_id without a fully formed connection' + ) + return await self._connection.audit_log_id() # region UNPROTECTED_GETS @@ -310,21 +357,24 @@ async def create_project( await self.check_project_creator_permissions(author) async with self.connection.transaction(): + audit_log_id = await self.audit_log_id() read_group_id = await self.gtable.create_group( - self.get_project_group_name(project_name, readonly=True) + self.get_project_group_name(project_name, readonly=True), + audit_log_id=audit_log_id, ) write_group_id = await self.gtable.create_group( - self.get_project_group_name(project_name, readonly=False) + self.get_project_group_name(project_name, readonly=False), + audit_log_id=audit_log_id, ) _query = """\ - INSERT INTO project (name, dataset, author, read_group_id, write_group_id) - VALUES (:name, :dataset, :author, :read_group_id, :write_group_id) + INSERT INTO project (name, dataset, audit_log_id, read_group_id, write_group_id) + VALUES (:name, :dataset, :audit_log_id, :read_group_id, :write_group_id) 
RETURNING ID""" values = { 'name': project_name, 'dataset': dataset_name, - 'author': author, + 'audit_log_id': await self.audit_log_id(), 'read_group_id': read_group_id, 'write_group_id': write_group_id, } @@ -342,9 +392,12 @@ async def update_project(self, project_name: str, update: dict, author: str): meta = update.get('meta') - fields: Dict[str, Any] = {'author': author, 'name': project_name} + fields: Dict[str, Any] = { + 'audit_log_id': await self.audit_log_id(), + 'name': project_name, + } - setters = ['author = :author'] + setters = ['audit_log_id = :audit_log_id'] if meta is not None and len(meta) > 0: fields['meta'] = to_db_json(meta) @@ -389,6 +442,16 @@ async def delete_project_data( INNER JOIN sample ON sample.id = sg.sample_id WHERE sample.project = :project ); +DELETE FROM analysis_sample WHERE sample_id in ( + SELECT s.id FROM sample s + WHERE s.project = :project +); +DELETE FROM analysis_sequencing_group WHERE analysis_id in ( + SELECT id FROM analysis WHERE project = :project +); +DELETE FROM analysis_sample WHERE analysis_id in ( + SELECT id FROM analysis WHERE project = :project +); DELETE FROM assay WHERE sample_id in (SELECT id FROM sample WHERE project = :project); DELETE FROM sequencing_group WHERE sample_id IN ( SELECT id FROM sample WHERE project = :project @@ -433,7 +496,9 @@ async def set_group_members(self, group_name: str, members: list[str], author: s f'User {author} does not have permission to add members to group {group_name}' ) group_id = await self.gtable.get_group_name_from_id(group_name) - await self.gtable.set_group_members(group_id, members, author=author) + await self.gtable.set_group_members( + group_id, members, audit_log_id=await self.audit_log_id() + ) # endregion CREATE / UPDATE @@ -563,16 +628,20 @@ async def check_which_groups_member_has( ) return set(r['gid'] for r in results) - async def create_group(self, name: str) -> int: + async def create_group(self, name: str, audit_log_id: int) -> int: """Create a new group""" _query = """ - INSERT INTO `group` (name) - VALUES (:name) + INSERT INTO `group` (name, audit_log_id) + VALUES (:name, :audit_log_id) RETURNING id """ - return await self.connection.fetch_val(_query, {'name': name}) + return await self.connection.fetch_val( + _query, {'name': name, 'audit_log_id': audit_log_id} + ) - async def set_group_members(self, group_id: int, members: list[str], author: str): + async def set_group_members( + self, group_id: int, members: list[str], audit_log_id: int + ): """ Set group members for a group (by id) """ @@ -585,11 +654,15 @@ async def set_group_members(self, group_id: int, members: list[str], author: str ) await self.connection.execute_many( """ - INSERT INTO group_member (group_id, member, author) - VALUES (:group_id, :member, :author) + INSERT INTO group_member (group_id, member, audit_log_id) + VALUES (:group_id, :member, :audit_log_id) """, [ - {'group_id': group_id, 'member': member, 'author': author} + { + 'group_id': group_id, + 'member': member, + 'audit_log_id': audit_log_id, + } for member in members ], ) diff --git a/db/python/tables/sample.py b/db/python/tables/sample.py index 0f15645a2..3457a2333 100644 --- a/db/python/tables/sample.py +++ b/db/python/tables/sample.py @@ -1,16 +1,17 @@ import asyncio -from datetime import date -from typing import Iterable, Any import dataclasses +from datetime import date +from typing import Any, Iterable -from db.python.connect import DbBase, NotFoundError +from db.python.tables.base import DbBase from db.python.utils import ( - to_db_json, - 
GenericFilterModel, GenericFilter, + GenericFilterModel, GenericMetaFilter, + NotFoundError, + to_db_json, ) -from db.python.tables.project import ProjectId +from models.models.project import ProjectId from models.models.sample import SampleInternal, sample_id_format @@ -47,6 +48,7 @@ class SampleTable(DbBase): 'active', 'type', 'project', + 'audit_log_id', ] # region GETS @@ -139,7 +141,6 @@ async def insert_sample( active: bool, meta: dict | None, participant_id: int | None, - author=None, project=None, ) -> int: """ @@ -152,7 +153,7 @@ async def insert_sample( ('meta', to_db_json(meta or {})), ('type', sample_type), ('active', active), - ('author', author or self.author), + ('audit_log_id', await self.audit_log_id()), ('project', project or self.project), ] @@ -179,16 +180,13 @@ async def update_sample( participant_id: int | None, external_id: str | None, type_: str | None, - author: str = None, active: bool = None, ): """Update a single sample""" - values: dict[str, Any] = { - 'author': author or self.author, - } + values: dict[str, Any] = {'audit_log_id': await self.audit_log_id()} fields = [ - 'author = :author', + 'audit_log_id = :audit_log_id', ] if participant_id: values['participant_id'] = participant_id @@ -220,7 +218,6 @@ async def merge_samples( self, id_keep: int = None, id_merge: int = None, - author: str = None, ): """Merge two samples together""" sid_merge = sample_id_format(id_merge) @@ -261,34 +258,43 @@ def dict_merge(meta1, meta2): meta_original.get('merged_from'), sid_merge ) meta: dict[str, Any] = dict_merge(meta_original, sample_merge.meta) - + audit_log_id = await self.audit_log_id() values: dict[str, Any] = { 'sample': { 'id': id_keep, - 'author': author or self.author, + 'audit_log_id': audit_log_id, 'meta': to_db_json(meta), }, - 'ids': {'id_keep': id_keep, 'id_merge': id_merge}, + 'ids': { + 'id_keep': id_keep, + 'id_merge': id_merge, + 'audit_log_id': audit_log_id, + }, } _query = """ UPDATE sample - SET author = :author, + SET audit_log_id = :audit_log_id, meta = :meta WHERE id = :id """ - _query_seqs = f""" - UPDATE sample_sequencing - SET sample_id = :id_keep + _query_seqs = """ + UPDATE assay + SET sample_id = :id_keep, audit_log_id = :audit_log_id WHERE sample_id = :id_merge """ # TODO: merge sequencing groups I guess? 
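# Illustrative sketch (not part of the changeset above): dict_merge() is not
# visible in this hunk; one plausible implementation -- a recursive merge where
# nested dicts are combined and the merged-in sample's values win -- could look
# like the below. This is an assumption about its behaviour, not the actual
# implementation.
def dict_merge(meta1: dict, meta2: dict) -> dict:
    merged = dict(meta1)
    for key, value in meta2.items():
        if key in merged and isinstance(merged[key], dict) and isinstance(value, dict):
            merged[key] = dict_merge(merged[key], value)  # combine nested dicts
        else:
            merged[key] = value  # otherwise the merged-in value wins
    return merged

assert dict_merge({'a': 1, 'n': {'x': 1}}, {'n': {'y': 2}}) == {'a': 1, 'n': {'x': 1, 'y': 2}}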
- _query_analyses = f""" - UPDATE analysis_sample - SET sample_id = :id_keep + _query_analyses = """ + UPDATE analysis_sequencing_group + SET sample_id = :id_keep, audit_log_id = :audit_log_id WHERE sample_id = :id_merge """ - _del_sample = f""" + _query_update_sample_with_audit_log = """ + UPDATE sample + SET audit_log_id = :audit_log_id + WHERE id = :id_merge + """ + _del_sample = """ DELETE FROM sample WHERE id = :id_merge """ @@ -297,11 +303,16 @@ def dict_merge(meta1, meta2): await self.connection.execute(_query, {**values['sample']}) await self.connection.execute(_query_seqs, {**values['ids']}) await self.connection.execute(_query_analyses, {**values['ids']}) - await self.connection.execute(_del_sample, {'id_merge': id_merge}) + await self.connection.execute( + _query_update_sample_with_audit_log, + {'id_merge': id_merge, 'audit_log_id': audit_log_id}, + ) + await self.connection.execute( + _del_sample, {'id_merge': id_merge, 'audit_log_id': audit_log_id} + ) project, new_sample = await self.get_sample_by_id(id_keep) new_sample.project = project - new_sample.author = author or self.author return new_sample @@ -312,9 +323,15 @@ async def update_many_participant_ids( Update participant IDs for many samples Expected len(ids) == len(participant_ids) """ - _query = 'UPDATE sample SET participant_id=:participant_id WHERE id = :id' + _query = """ + UPDATE sample + SET participant_id=:participant_id, audit_log_id = :audit_log_id + WHERE id = :id + """ + audit_log_id = await self.audit_log_id() values = [ - {'id': i, 'participant_id': pid} for i, pid in zip(ids, participant_ids) + {'id': i, 'participant_id': pid, 'audit_log_id': audit_log_id} + for i, pid in zip(ids, participant_ids) ] await self.connection.execute_many(_query, values) @@ -420,6 +437,7 @@ async def get_history_of_sample(self, id_: int): 'type', 'project', 'author', + 'audit_log_id', ] keys_str = ', '.join(keys) _query = f'SELECT {keys_str} FROM sample FOR SYSTEM_TIME ALL WHERE id = :id' diff --git a/db/python/tables/sequencing_group.py b/db/python/tables/sequencing_group.py index 0f315b191..d154f43e5 100644 --- a/db/python/tables/sequencing_group.py +++ b/db/python/tables/sequencing_group.py @@ -4,14 +4,16 @@ from datetime import date from typing import Any -from db.python.connect import DbBase, NoOpAenter, NotFoundError +from db.python.tables.base import DbBase from db.python.utils import ( GenericFilter, GenericFilterModel, GenericMetaFilter, - ProjectId, + NoOpAenter, + NotFoundError, to_db_json, ) +from models.models.project import ProjectId from models.models.sequencing_group import ( SequencingGroupInternal, SequencingGroupInternalId, @@ -374,7 +376,6 @@ async def create_sequencing_group( platform: str, assay_ids: list[int], meta: dict = None, - author: str = None, open_transaction=True, ) -> int: """Create sequence group""" @@ -405,16 +406,17 @@ async def create_sequencing_group( _query = """ INSERT INTO sequencing_group - (sample_id, type, technology, platform, meta, author, archived) - VALUES (:sample_id, :type, :technology, :platform, :meta, :author, false) + (sample_id, type, technology, platform, meta, audit_log_id, archived) + VALUES + (:sample_id, :type, :technology, :platform, :meta, :audit_log_id, false) RETURNING id; """ _seqg_linker_query = """ INSERT INTO sequencing_group_assay - (sequencing_group_id, assay_id, author) + (sequencing_group_id, assay_id, audit_log_id) VALUES - (:seqgroup, :assayid, :author) + (:seqgroup, :assayid, :audit_log_id) """ values = { @@ -437,13 +439,13 @@ async def 
create_sequencing_group( id_of_seq_group = await self.connection.fetch_val( _query, - {**values, 'author': author or self.author}, + {**values, 'audit_log_id': await self.audit_log_id()}, ) assay_id_insert_values = [ { 'seqgroup': id_of_seq_group, 'assayid': s, - 'author': author or self.author, + 'audit_log_id': await self.audit_log_id(), } for s in assay_ids ] @@ -459,9 +461,10 @@ async def update_sequencing_group( """ Update meta / platform on sequencing_group """ - updaters = [] + updaters = ['audit_log_id = :audit_log_id'] values: dict[str, Any] = { 'seqgid': sequencing_group_id, + 'audit_log_id': await self.audit_log_id(), } if meta: @@ -486,21 +489,28 @@ async def archive_sequencing_groups(self, sequencing_group_id: list[int]): """ _query = """ UPDATE sequencing_group - SET archived = 1, author = :author + SET archived = 1, audit_log_id = :audit_log_id WHERE id = :sequencing_group_id; """ # do this so we can reuse the sequencing_group_ids _external_id_query = """ UPDATE sequencing_group_external_id - SET nullIfInactive = NULL + SET nullIfInactive = NULL, audit_log_id = :audit_log_id WHERE sequencing_group_id = :sequencing_group_id; """ await self.connection.execute( - _query, {'sequencing_group_id': sequencing_group_id, 'author': self.author} + _query, + { + 'sequencing_group_id': sequencing_group_id, + 'audit_log_id': await self.audit_log_id(), + }, ) await self.connection.execute( _external_id_query, - {'sequencing_group_id': sequencing_group_id}, + { + 'sequencing_group_id': sequencing_group_id, + 'audit_log_id': await self.audit_log_id(), + }, ) async def get_type_numbers_for_project(self, project) -> dict[str, int]: diff --git a/db/python/utils.py b/db/python/utils.py index 45979c9c1..b714081d1 100644 --- a/db/python/utils.py +++ b/db/python/utils.py @@ -7,7 +7,6 @@ from typing import Any, Generic, Sequence, TypeVar T = TypeVar('T') -ProjectId = int levels_map = {'DEBUG': logging.DEBUG, 'INFO': logging.INFO, 'WARNING': logging.WARNING} diff --git a/deploy/python/version.txt b/deploy/python/version.txt index f22d756da..28179fc1f 100644 --- a/deploy/python/version.txt +++ b/deploy/python/version.txt @@ -1 +1 @@ -6.5.0 +6.6.2 diff --git a/metamist/graphql/__init__.py b/metamist/graphql/__init__.py index 44b16ba7b..2cd00b996 100644 --- a/metamist/graphql/__init__.py +++ b/metamist/graphql/__init__.py @@ -9,7 +9,9 @@ from gql import Client, gql as gql_constructor from gql.transport.aiohttp import AIOHTTPTransport +from gql.transport.aiohttp import log as aiohttp_logger from gql.transport.requests import RequestsHTTPTransport +from gql.transport.requests import log as requests_logger from cpg_utils.cloud import get_google_identity_token @@ -54,13 +56,18 @@ def configure_sync_client( if _sync_client and not force_recreate: return _sync_client - token = auth_token or get_google_identity_token( - target_audience=metamist.configuration.sm_url - ) - transport = RequestsHTTPTransport( - url=url or get_sm_url(), - headers={'Authorization': f'Bearer {token}'}, - ) + env = os.getenv('SM_ENVIRONMENT', 'PRODUCTION').lower() + if env == 'local': + transport = RequestsHTTPTransport(url=url or get_sm_url()) + else: + token = auth_token or get_google_identity_token( + target_audience=metamist.configuration.sm_url + ) + transport = RequestsHTTPTransport( + url=url or get_sm_url(), + headers={'Authorization': f'Bearer {token}'}, + ) + _sync_client = Client( transport=transport, schema=schema, fetch_schema_from_transport=schema is None ) @@ -78,13 +85,18 @@ async def configure_async_client( if 
_async_client and not force_recreate:
         return _async_client
 
-    token = auth_token or get_google_identity_token(
-        target_audience=metamist.configuration.sm_url
-    )
-    transport = AIOHTTPTransport(
-        url=url or get_sm_url(),
-        headers={'Authorization': f'Bearer {token}'},
-    )
+    env = os.getenv('SM_ENVIRONMENT', 'PRODUCTION').lower()
+    if env == 'local':
+        transport = AIOHTTPTransport(url=url or get_sm_url())
+    else:
+        token = auth_token or get_google_identity_token(
+            target_audience=metamist.configuration.sm_url
+        )
+        transport = AIOHTTPTransport(
+            url=url or get_sm_url(),
+            headers={'Authorization': f'Bearer {token}'},
+        )
+
     _async_client = Client(
         transport=transport, schema=schema, fetch_schema_from_transport=schema is None
     )
@@ -126,26 +138,39 @@ def validate(doc: DocumentNode, client=None, use_local_schema=False):
 
 # use older style typing to broaden supported Python versions
 def query(
-    _query: str | DocumentNode, variables: Dict = None, client: Client = None
+    _query: str | DocumentNode, variables: Dict = None, client: Client = None, log_response: bool = False
 ) -> Dict[str, Any]:
     """Query the metamist GraphQL API"""
     if variables is None:
         variables = {}
 
+    # disable logging for gql
+    current_level = requests_logger.level
+    if not log_response:
+        requests_logger.setLevel('WARNING')
+
     response = (client or configure_sync_client()).execute_sync(
         _query if isinstance(_query, DocumentNode) else gql(_query),
         variable_values=variables,
     )
+
+    if not log_response:
+        requests_logger.setLevel(current_level)
+
     return response
 
 
 async def query_async(
-    _query: str | DocumentNode, variables: Dict = None, client: Client = None
+    _query: str | DocumentNode, variables: Dict = None, client: Client = None, log_response: bool = False
 ) -> Dict[str, Any]:
     """Asynchronously query the Metamist GraphQL API"""
     if variables is None:
         variables = {}
 
+    # disable logging for gql
+    current_level = aiohttp_logger.level
+    if not log_response:
+        aiohttp_logger.setLevel('WARNING')
+
     if not client:
         client = await configure_async_client()
 
@@ -153,4 +178,8 @@ async def query_async(
         _query if isinstance(_query, DocumentNode) else gql(_query),
         variable_values=variables,
     )
+
+    if not log_response:
+        aiohttp_logger.setLevel(current_level)
+
     return response
diff --git a/models/base.py b/models/base.py
index 389c38a56..2b728e2af 100644
--- a/models/base.py
+++ b/models/base.py
@@ -7,3 +7,15 @@ class SMBase(BaseModel):
     """Base object for all models"""
+
+
+def parse_sql_bool(val: str | int | bytes) -> bool | None:
+    """Parse a string from a sql bool"""
+    if val is None:
+        return None
+    if isinstance(val, int):
+        return bool(val)
+    if isinstance(val, bytes):
+        return bool(ord(val))
+
+    raise ValueError(f'Unknown type for active: {type(val)}')
diff --git a/models/enums/__init__.py b/models/enums/__init__.py
index 14bcb9d5a..14f047786 100644
--- a/models/enums/__init__.py
+++ b/models/enums/__init__.py
@@ -1,3 +1,4 @@
 from models.enums.analysis import AnalysisStatus
+from models.enums.billing import BillingSource, BillingTimeColumn, BillingTimePeriods
 from models.enums.search import SearchResponseType
 from models.enums.web import MetaSearchEntityPrefix
diff --git a/models/enums/billing.py b/models/enums/billing.py
new file mode 100644
index 000000000..efcff271c
--- /dev/null
+++ b/models/enums/billing.py
@@ -0,0 +1,31 @@
+from enum import Enum
+
+
+class BillingSource(str, Enum):
+    """List of billing sources"""
+
+    RAW = 'raw'
+    AGGREGATE = 'aggregate'
+    EXTENDED = 'extended'
+    BUDGET = 'budget'
+    GCP_BILLING = 'gcp_billing'
+    BATCHES = 
'batches' + + +class BillingTimePeriods(str, Enum): + """List of billing grouping time periods""" + + # grouping time periods + DAY = 'day' + WEEK = 'week' + MONTH = 'month' + INVOICE_MONTH = 'invoice_month' + + +class BillingTimeColumn(str, Enum): + """List of billing time columns""" + + DAY = 'day' + USAGE_START_TIME = 'usage_start_time' + USAGE_END_TIME = 'usage_end_time' + EXPORT_TIME = 'export_time' diff --git a/models/models/__init__.py b/models/models/__init__.py index ab2e0d707..911974169 100644 --- a/models/models/__init__.py +++ b/models/models/__init__.py @@ -1,3 +1,4 @@ +from models.base import parse_sql_bool from models.models.analysis import ( Analysis, AnalysisInternal, @@ -9,13 +10,15 @@ SequencingGroupSizeModel, ) from models.models.assay import Assay, AssayInternal, AssayUpsert, AssayUpsertInternal +from models.models.audit_log import AuditLogId, AuditLogInternal from models.models.billing import ( BillingColumn, - BillingRowRecord, - BillingTotalCostQueryModel, - BillingTotalCostRecord, BillingCostBudgetRecord, BillingCostDetailsRecord, + BillingHailBatchCostRecord, + BillingInternal, + BillingTotalCostQueryModel, + BillingTotalCostRecord, ) from models.models.cohort import Cohort from models.models.family import ( @@ -33,7 +36,7 @@ ParticipantUpsert, ParticipantUpsertInternal, ) -from models.models.project import Project +from models.models.project import Project, ProjectId from models.models.sample import ( NestedSample, NestedSampleInternal, diff --git a/models/models/analysis.py b/models/models/analysis.py index 4d33bfa04..fb6e3152d 100644 --- a/models/models/analysis.py +++ b/models/models/analysis.py @@ -85,10 +85,10 @@ def to_external(self): class Analysis(BaseModel): """Model for Analysis""" - id: int | None type: str status: AnalysisStatus - output: str = None + id: int | None = None + output: str | None = None sequencing_group_ids: list[str] = [] author: str | None = None timestamp_completed: str | None = None diff --git a/models/models/assay.py b/models/models/assay.py index 04658ad44..139d60327 100644 --- a/models/models/assay.py +++ b/models/models/assay.py @@ -12,7 +12,7 @@ class AssayInternal(SMBase): sample_id: int meta: dict[str, Any] | None type: str - external_ids: dict[str, str] | None = None + external_ids: dict[str, str] | None = {} def __repr__(self): return ', '.join(f'{k}={v}' for k, v in vars(self).items()) diff --git a/models/models/audit_log.py b/models/models/audit_log.py new file mode 100644 index 000000000..20969dd50 --- /dev/null +++ b/models/models/audit_log.py @@ -0,0 +1,31 @@ +import datetime +import json + +from models.base import SMBase +from models.models.project import ProjectId + +AuditLogId = int + + +class AuditLogInternal(SMBase): + """ + Model for audit_log + """ + + id: AuditLogId + timestamp: datetime.datetime + author: str + auth_project: ProjectId + on_behalf_of: str | None + ar_guid: str | None + comment: str | None + meta: dict | None + + @staticmethod + def from_db(d: dict): + """Take DB mapping object, and return SampleSequencing""" + meta = {} + if 'meta' in d: + meta = json.loads(d.pop('meta')) + + return AuditLogInternal(meta=meta, **d) diff --git a/models/models/billing.py b/models/models/billing.py index 481ea77ce..9587ed44c 100644 --- a/models/models/billing.py +++ b/models/models/billing.py @@ -1,135 +1,75 @@ import datetime from enum import Enum -from db.python.tables.billing import BillingFilter -from db.python.utils import GenericFilter - +from db.python.tables.bq.billing_filter import BillingFilter +from 
db.python.tables.bq.generic_bq_filter import GenericBQFilter from models.base import SMBase +from models.enums.billing import BillingSource, BillingTimeColumn, BillingTimePeriods -class BillingQueryModel(SMBase): - """Used to query for billing""" - - # topic is cluster index, provide some values to make it more efficient - topic: list[str] | None = None - - # make date required, to avoid full table scan - date: str +class BillingInternal(SMBase): + """Model for Analysis""" - cost_category: list[str] | None = None - - def to_filter(self) -> BillingFilter: - """Convert to internal analysis filter""" - return BillingFilter( - topic=GenericFilter(in_=self.topic) if self.topic else None, - date=GenericFilter(eq=self.date), - cost_category=GenericFilter(in_=self.cost_category) - if self.cost_category - else None, - ) - - def __hash__(self): - """Create hash for this object to use in caching""" - return hash(self.json()) - - -class BillingRowRecord(SMBase): - """Return class for the Billing record""" - - id: str + id: str | None + ar_guid: str | None + gcp_project: str | None topic: str | None - service_id: str | None - service_description: str | None - - sku_id: str | None - sku_description: str | None - - usage_start_time: datetime.datetime | None - usage_end_time: datetime.datetime | None - - gcp_project_id: str | None - gcp_project_number: str | None - gcp_project_name: str | None - - # labels - dataset: str | None batch_id: str | None - job_id: str | None - batch_name: str | None - sequencing_type: str | None - stage: str | None - sequencing_group: str | None - - export_time: datetime.datetime | None - cost: str | None - currency: str | None - currency_conversion_rate: str | None - invoice_month: str | None - cost_type: str | None - - class Config: - """Config for BillingRowRecord Response""" - - orm_mode = True + cost_category: str | None + cost: float | None + day: datetime.date | None @staticmethod - def from_json(record): - """Create BillingRowRecord from json""" - - record['service'] = record['service'] if record['service'] else {} - record['project'] = record['project'] if record['project'] else {} - record['invoice'] = record['invoice'] if record['invoice'] else {} - record['sku'] = record['sku'] if record['sku'] else {} - - labels = {} - - if record['labels']: - for lbl in record['labels']: - labels[lbl['key']] = lbl['value'] - - record['labels'] = labels - - return BillingRowRecord( - id=record['id'], - topic=record['topic'], - service_id=record['service'].get('id'), - service_description=record['service'].get('description'), - sku_id=record['sku'].get('id'), - sku_description=record['sku'].get('description'), - usage_start_time=record['usage_start_time'], - usage_end_time=record['usage_end_time'], - gcp_project_id=record['project'].get('id'), - gcp_project_number=record['project'].get('number'), - gcp_project_name=record['project'].get('name'), - # labels - dataset=record['labels'].get('dataset'), - batch_id=record['labels'].get('batch_id'), - job_id=record['labels'].get('job_id'), - batch_name=record['labels'].get('batch_name'), - sequencing_type=record['labels'].get('sequencing_type'), - stage=record['labels'].get('stage'), - sequencing_group=record['labels'].get('sequencing_group'), - export_time=record['export_time'], - cost=record['cost'], - currency=record['currency'], - currency_conversion_rate=record['currency_conversion_rate'], - invoice_month=record['invoice'].get('month', ''), - cost_type=record['cost_type'], + def from_db(**kwargs): + """ + Convert from db keys, mainly 
converting id to id_ + """ + return BillingInternal( + id=kwargs.get('id'), + ar_guid=kwargs.get('ar_guid', kwargs.get('ar-guid')), + gcp_project=kwargs.get('gcp_project'), + topic=kwargs.get('topic'), + batch_id=kwargs.get('batch_id'), + cost_category=kwargs.get('cost_category'), + cost=kwargs.get('cost'), + day=kwargs.get('day'), ) class BillingColumn(str, Enum): """List of billing columns""" - # base view columns + # raw view columns + ID = 'id' TOPIC = 'topic' - PROJECT = 'gcp_project' + SERVICE = 'service' + SKU = 'sku' + USAGE_START_TIME = 'usage_start_time' + USAGE_END_TIME = 'usage_end_time' + PROJECT = 'project' + LABELS = 'labels' + SYSTEM_LABELS = 'system_labels' + LOCATION = 'location' + EXPORT_TIME = 'export_time' + COST = 'cost' + CURRENCY = 'currency' + CURRENCY_CONVERSION_RATE = 'currency_conversion_rate' + USAGE = 'usage' + CREDITS = 'credits' + INVOICE = 'invoice' + COST_TYPE = 'cost_type' + ADJUSTMENT_INFO = 'adjustment_info' + + # base view columns + # TOPIC = 'topic' + # SKU = 'sku' + # CURRENCY = 'currency' + # COST = 'cost' + # LABELS = 'labels' + GCP_PROJECT = 'gcp_project' DAY = 'day' COST_CATEGORY = 'cost_category' - SKU = 'sku' AR_GUID = 'ar_guid' - CURRENCY = 'currency' - COST = 'cost' INVOICE_MONTH = 'invoice_month' # extended, filtered view columns @@ -138,23 +78,122 @@ class BillingColumn(str, Enum): SEQUENCING_TYPE = 'sequencing_type' STAGE = 'stage' SEQUENCING_GROUP = 'sequencing_group' + COMPUTE_CATEGORY = 'compute_category' + CROMWELL_SUB_WORKFLOW_NAME = 'cromwell_sub_workflow_name' + CROMWELL_WORKFLOW_ID = 'cromwell_workflow_id' + GOOG_PIPELINES_WORKER = 'goog_pipelines_worker' + WDL_TASK_NAME = 'wdl_task_name' + NAMESPACE = 'namespace' + + @classmethod + def can_group_by(cls, value: 'BillingColumn') -> bool: + """ + Return True if column can be grouped by + TODO: If any new columns are added above and cannot be in a group by, add them here + This could be record, array or struct type + """ + return value not in ( + BillingColumn.COST, + BillingColumn.SERVICE, + # BillingColumn.SKU, + BillingColumn.PROJECT, + BillingColumn.LABELS, + BillingColumn.SYSTEM_LABELS, + BillingColumn.LOCATION, + BillingColumn.USAGE, + BillingColumn.CREDITS, + BillingColumn.INVOICE, + BillingColumn.ADJUSTMENT_INFO, + ) + + @classmethod + def is_extended_column(cls, value: 'BillingColumn') -> bool: + """Return True if column is extended""" + return value in ( + BillingColumn.DATASET, + BillingColumn.BATCH_ID, + BillingColumn.SEQUENCING_TYPE, + BillingColumn.STAGE, + BillingColumn.SEQUENCING_GROUP, + BillingColumn.COMPUTE_CATEGORY, + BillingColumn.CROMWELL_SUB_WORKFLOW_NAME, + BillingColumn.CROMWELL_WORKFLOW_ID, + BillingColumn.GOOG_PIPELINES_WORKER, + BillingColumn.WDL_TASK_NAME, + BillingColumn.NAMESPACE, + ) + + @classmethod + def str_to_enum(cls, value: str) -> 'BillingColumn': + """Convert string to enum""" + # all column names have underscore in SQL, but dash in UI / stored data + adjusted_value = value.replace('-', '_') + str_to_enum = {v.value: v for k, v in BillingColumn.__members__.items()} + return str_to_enum[adjusted_value] + + @classmethod + def raw_cols(cls) -> list[str]: + """Return list of raw column names""" + return [ + BillingColumn.ID.value, + BillingColumn.TOPIC.value, + BillingColumn.SERVICE.value, + BillingColumn.SKU.value, + BillingColumn.USAGE_START_TIME.value, + BillingColumn.USAGE_END_TIME.value, + BillingColumn.PROJECT.value, + BillingColumn.LABELS.value, + BillingColumn.SYSTEM_LABELS.value, + BillingColumn.LOCATION.value, + 
BillingColumn.EXPORT_TIME.value, + BillingColumn.COST.value, + BillingColumn.CURRENCY.value, + BillingColumn.CURRENCY_CONVERSION_RATE.value, + BillingColumn.USAGE.value, + BillingColumn.CREDITS.value, + BillingColumn.INVOICE.value, + BillingColumn.COST_TYPE.value, + BillingColumn.ADJUSTMENT_INFO.value, + ] + + @classmethod + def standard_cols(cls) -> list[str]: + """Return list of standard column names""" + return [ + BillingColumn.TOPIC.value, + BillingColumn.GCP_PROJECT.value, + BillingColumn.SKU.value, + BillingColumn.CURRENCY.value, + BillingColumn.COST.value, + BillingColumn.LABELS.value, + BillingColumn.DAY.value, + BillingColumn.COST_CATEGORY.value, + BillingColumn.AR_GUID.value, + BillingColumn.INVOICE_MONTH.value, + ] @classmethod def extended_cols(cls) -> list[str]: """Return list of extended column names""" return [ - 'dataset', - 'batch_id', - 'sequencing_type', - 'stage', - 'sequencing_group', - 'ar_guid' + BillingColumn.DATASET.value, + BillingColumn.BATCH_ID.value, + BillingColumn.SEQUENCING_TYPE.value, + BillingColumn.STAGE.value, + BillingColumn.SEQUENCING_GROUP.value, + BillingColumn.AR_GUID.value, + BillingColumn.COMPUTE_CATEGORY.value, + BillingColumn.CROMWELL_SUB_WORKFLOW_NAME.value, + BillingColumn.CROMWELL_WORKFLOW_ID.value, + BillingColumn.GOOG_PIPELINES_WORKER.value, + BillingColumn.WDL_TASK_NAME.value, + BillingColumn.NAMESPACE.value, ] @staticmethod def generate_all_title(record) -> str: """Generate Column as All Title""" - if record == BillingColumn.PROJECT: + if record == BillingColumn.GCP_PROJECT: return 'All GCP Projects' return f'All {record.title()}s' @@ -170,20 +209,44 @@ class BillingTotalCostQueryModel(SMBase): fields: list[BillingColumn] start_date: str end_date: str - # optional, can be aggregate or gcp_billing - source: str | None = None + # optional, can be raw, aggregate or gcp_billing + source: BillingSource | None = None # optional - filters: dict[BillingColumn, str] | None = None + filters: dict[BillingColumn, str | list | dict] | None = None + # optional, AND or OR + filters_op: str | None = None + group_by: bool = True + # order by, reverse= TRUE for DESC, FALSE for ASC order_by: dict[BillingColumn, bool] | None = None limit: int | None = None offset: int | None = None + # default to day, can be day, week, month, invoice_month + time_column: BillingTimeColumn | None = None + time_periods: BillingTimePeriods | None = None + + # optional, show the min cost, e.g. 
0.01, if not set, will show all + min_cost: float | None = None + def __hash__(self): """Create hash for this object to use in caching""" return hash(self.json()) + def to_filter(self) -> BillingFilter: + """ + Convert to internal analysis filter + """ + billing_filter = BillingFilter() + if self.filters: + # add filters as attributes + for fk, fv in self.filters.items(): + # fk is BillColumn, fv is value + setattr(billing_filter, fk.value, GenericBQFilter(eq=fv)) + + return billing_filter + class BillingTotalCostRecord(SMBase): """Return class for the Billing Total Cost record""" @@ -192,14 +255,22 @@ class BillingTotalCostRecord(SMBase): topic: str | None gcp_project: str | None cost_category: str | None - sku: str | None + sku: str | dict | None + invoice_month: str | None ar_guid: str | None + # extended columns dataset: str | None batch_id: str | None sequencing_type: str | None stage: str | None sequencing_group: str | None + compute_category: str | None + cromwell_sub_workflow_name: str | None + cromwell_workflow_id: str | None + goog_pipelines_worker: str | None + wdl_task_name: str | None + namespace: str | None cost: float currency: str | None @@ -213,12 +284,19 @@ def from_json(record): gcp_project=record.get('gcp_project'), cost_category=record.get('cost_category'), sku=record.get('sku'), + invoice_month=record.get('invoice_month'), ar_guid=record.get('ar_guid'), dataset=record.get('dataset'), batch_id=record.get('batch_id'), sequencing_type=record.get('sequencing_type'), stage=record.get('stage'), sequencing_group=record.get('sequencing_group'), + compute_category=record.get('compute_category'), + cromwell_sub_workflow_name=record.get('cromwell_sub_workflow_name'), + cromwell_workflow_id=record.get('cromwell_workflow_id'), + goog_pipelines_worker=record.get('goog_pipelines_worker'), + wdl_task_name=record.get('wdl_task_name'), + namespace=record.get('namespace'), cost=record.get('cost'), currency=record.get('currency'), ) @@ -256,6 +334,7 @@ class BillingCostBudgetRecord(SMBase): storage_daily: float | None details: list[BillingCostDetailsRecord] | None budget_spent: float | None + budget: float | None last_loaded_day: str | None @@ -274,5 +353,14 @@ def from_json(record): BillingCostDetailsRecord.from_json(row) for row in record.get('details') ], budget_spent=record.get('budget_spent'), + budget=record.get('budget'), last_loaded_day=record.get('last_loaded_day'), ) + + +class BillingHailBatchCostRecord(SMBase): + """Return class for the Billing Cost by batch_id/ar_guid""" + + ar_guid: str | None + batch_ids: list[str] | None + costs: list[dict] | None diff --git a/models/models/participant.py b/models/models/participant.py index ab843e343..de0dc6ef5 100644 --- a/models/models/participant.py +++ b/models/models/participant.py @@ -1,8 +1,8 @@ import json -from db.python.utils import ProjectId from models.base import OpenApiGenNoneType, SMBase from models.models.family import FamilySimple, FamilySimpleInternal +from models.models.project import ProjectId from models.models.sample import ( NestedSample, NestedSampleInternal, @@ -22,6 +22,8 @@ class ParticipantInternal(SMBase): karyotype: str | None = None meta: dict + audit_log_id: int | None = None + @classmethod def from_db(cls, data: dict): """Convert from db keys, mainly converting parsing meta""" diff --git a/models/models/project.py b/models/models/project.py index 4372a086d..9ca19542f 100644 --- a/models/models/project.py +++ b/models/models/project.py @@ -3,11 +3,13 @@ from models.base import SMBase +ProjectId = int + 
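# Illustrative sketch only (not part of this patch): after this change,
# ProjectId is imported from models.models.project rather than
# db.python.utils. The helper below is hypothetical and exists purely to show
# the relocated import path used in a type hint.
from collections import Counter

from models.models.project import ProjectId


def samples_per_project(project_ids: list[ProjectId]) -> dict[ProjectId, int]:
    """Count occurrences of each project id (illustrative only)."""
    return dict(Counter(project_ids))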
class Project(SMBase): """Row for project in 'project' table""" - id: Optional[int] = None + id: Optional[ProjectId] = None name: Optional[str] = None dataset: Optional[str] = None meta: Optional[dict] = None diff --git a/models/models/sample.py b/models/models/sample.py index 0c2e5271d..5f183ff2e 100644 --- a/models/models/sample.py +++ b/models/models/sample.py @@ -1,6 +1,6 @@ import json -from models.base import OpenApiGenNoneType, SMBase +from models.base import OpenApiGenNoneType, SMBase, parse_sql_bool from models.models.assay import Assay, AssayInternal, AssayUpsert, AssayUpsertInternal from models.models.sequencing_group import ( NestedSequencingGroup, @@ -31,11 +31,8 @@ def from_db(d: dict): _id = d.pop('id', None) type_ = d.pop('type', None) meta = d.pop('meta', None) - active = d.pop('active', None) - if active is not None: - active = bool( - ord(active) if isinstance(active, (str, bytes, bytearray)) else active - ) + active = parse_sql_bool(d.pop('active', None)) + if meta: if isinstance(meta, bytes): meta = meta.decode() diff --git a/models/models/search.py b/models/models/search.py index cceb68a92..08b3a6d30 100644 --- a/models/models/search.py +++ b/models/models/search.py @@ -1,7 +1,7 @@ -from db.python.utils import ProjectId from models.base import SMBase from models.enums.search import SearchResponseType from models.enums.web import MetaSearchEntityPrefix +from models.models.project import ProjectId class SearchResponseData(SMBase): diff --git a/models/models/sequencing_group.py b/models/models/sequencing_group.py index 382b26ee9..15df04b6d 100644 --- a/models/models/sequencing_group.py +++ b/models/models/sequencing_group.py @@ -37,7 +37,7 @@ class SequencingGroupInternal(SMBase): platform: str | None = None meta: dict[str, str] | None = None sample_id: int | None = None - external_ids: dict[str, str] | None = None + external_ids: dict[str, str] | None = {} archived: bool | None = None project: int | None = None diff --git a/openapi-templates/api_client.mustache b/openapi-templates/api_client.mustache index b083d401d..d78bf1909 100644 --- a/openapi-templates/api_client.mustache +++ b/openapi-templates/api_client.mustache @@ -11,6 +11,8 @@ import typing from urllib.parse import quote from urllib3.fields import RequestField +from cpg_utils.config import try_get_ar_guid + {{#tornado}} import tornado.gen {{/tornado}} @@ -67,6 +69,9 @@ class ApiClient(object): self.rest_client = rest.RESTClientObject(configuration) self.default_headers = {} + ar_guid = try_get_ar_guid() + if ar_guid: + self.default_headers['sm-ar-guid'] = ar_guid if header_name is not None: self.default_headers[header_name] = header_value self.cookie = cookie diff --git a/openapi-templates/configuration.mustache b/openapi-templates/configuration.mustache index b1aafc2ac..609dea623 100644 --- a/openapi-templates/configuration.mustache +++ b/openapi-templates/configuration.mustache @@ -24,10 +24,10 @@ JSON_SCHEMA_VALIDATION_KEYWORDS = { sm_url = getenv('SM_URL') if not sm_url: - env = getenv('SM_ENVIRONMENT', 'PRODUCTION') - if 'local' in env.lower(): + env = getenv('SM_ENVIRONMENT', 'PRODUCTION').lower() + if 'local' in env: sm_url = "http://localhost:8000" - elif 'dev' in env.lower(): + elif 'dev' in env: sm_url = 'https://sample-metadata-api-dev-mnrpw3mdza-ts.a.run.app' else: sm_url = 'https://sample-metadata-api-mnrpw3mdza-ts.a.run.app' @@ -192,7 +192,7 @@ conf = {{{packageName}}}.Configuration( ): """Constructor """ - env = getenv('SM_ENVIRONMENT', 'PRODUCTION') + self.env = getenv('SM_ENVIRONMENT', 
'PRODUCTION').lower() self._base_path = host or sm_url """Default Base url @@ -491,6 +491,9 @@ conf = {{{packageName}}}.Configuration( :return: The Auth Settings information dict. """ + if self.env == 'local': + return {} + auth = { 'HTTPBearer': { 'type': 'bearer', diff --git a/scripts/create_md5s.py b/scripts/create_md5s.py new file mode 100644 index 000000000..f99f64906 --- /dev/null +++ b/scripts/create_md5s.py @@ -0,0 +1,69 @@ +import os + +import click +from cpg_utils.hail_batch import get_batch, get_config, copy_common_env +from google.cloud import storage + + +def create_md5s_for_files_in_directory(skip_filetypes: tuple[str, str], force_recreate: bool, gs_dir): + """Validate files with MD5s in the provided gs directory""" + b = get_batch(f'Create md5 checksums for files in {gs_dir}') + + if not gs_dir.startswith('gs://'): + raise ValueError(f'Expected GS directory, got: {gs_dir}') + + billing_project = get_config()['workflow']['gcp_billing_project'] + driver_image = get_config()['workflow']['driver_image'] + + bucket_name, *components = gs_dir[5:].split('/') + + client = storage.Client() + bucket = client.bucket(bucket_name, user_project=billing_project) + blobs = bucket.list_blobs(prefix='/'.join(components)) + files: set[str] = {f'gs://{bucket_name}/{blob.name}' for blob in blobs} + for filepath in files: + if filepath.endswith('.md5') or filepath.endswith(skip_filetypes): + continue + if f'{filepath}.md5' in files and not force_recreate: + print(f'{filepath}.md5 already exists, skipping') + continue + + print('Creating md5 for', filepath) + job = b.new_job(f'Create {os.path.basename(filepath)}.md5') + create_md5(job, filepath, billing_project, driver_image) + + b.run(wait=False) + + +def create_md5(job, file, billing_project, driver_image): + """ + Streams the file with gsutil and calculates the md5 checksum, + then uploads the checksum to the same path as filename.md5. 
+ """ + copy_common_env(job) + job.image(driver_image) + md5 = f'{file}.md5' + job.command( + f"""\ + set -euxo pipefail + gcloud -q auth activate-service-account --key-file=$GOOGLE_APPLICATION_CREDENTIALS + gsutil -u {billing_project} cat {file} | md5sum | cut -d " " -f1 > /tmp/uploaded.md5 + gsutil -u {billing_project} cp /tmp/uploaded.md5 {md5} + """ + ) + + return job + + +@click.command() +@click.option('--skip-filetypes', '-s', default=('.crai', '.tbi'), multiple=True) +@click.option('--force-recreate', '-f', is_flag=True, default=False) +@click.argument('gs_dir') +def main(skip_filetypes: tuple[str, str], force_recreate: bool, gs_dir: str): + """Scans the directory for files and creates md5 checksums for them.""" + create_md5s_for_files_in_directory(skip_filetypes, force_recreate, gs_dir=gs_dir) + + +if __name__ == '__main__': + # pylint: disable=no-value-for-parameter + main() diff --git a/scripts/create_test_subset.py b/scripts/create_test_subset.py index cc1adf52b..14f8ccd00 100755 --- a/scripts/create_test_subset.py +++ b/scripts/create_test_subset.py @@ -276,7 +276,7 @@ def transfer_samples_sgs_assays( type=sample_type or None, meta=(copy_files_in_dict(s['meta'], project) or {}), participant_id=existing_pid, - sequencing_groups=upsert_sequencing_groups(s, existing_data), + sequencing_groups=upsert_sequencing_groups(s, existing_data, project), id=existing_sid, ) @@ -289,7 +289,7 @@ def transfer_samples_sgs_assays( def upsert_sequencing_groups( - sample: dict, existing_data: dict + sample: dict, existing_data: dict, project: str ) -> list[SequencingGroupUpsert]: """Create SG Upsert Objects for a sample""" sgs_to_upsert: list[SequencingGroupUpsert] = [] @@ -306,7 +306,7 @@ def upsert_sequencing_groups( technology=sg.get('technology'), type=sg.get('type'), assays=upsert_assays( - sg, existing_sgid, existing_data, sample.get('externalId') + sg, existing_sgid, existing_data, sample.get('externalId'), project ), ) sgs_to_upsert.append(sg_upsert) @@ -315,7 +315,11 @@ def upsert_sequencing_groups( def upsert_assays( - sg: dict, existing_sgid: str | None, existing_data: dict, sample_external_id + sg: dict, + existing_sgid: str | None, + existing_data: dict, + sample_external_id, + project: str, ) -> list[AssayUpsert]: """Create Assay Upsert Objects for a sequencing group""" print(sg) @@ -325,17 +329,14 @@ def upsert_assays( # Check if assay exists if existing_sgid: _existing_assay = get_existing_assay( - existing_data, - sample_external_id, - existing_sgid, - assay + existing_data, sample_external_id, existing_sgid, assay ) existing_assay_id = _existing_assay.get('id') if _existing_assay else None assay_upsert = AssayUpsert( type=assay.get('type'), id=existing_assay_id, external_ids=assay.get('externalIds') or {}, - meta=assay.get('meta'), + meta=copy_files_in_dict(assay.get('meta'), project), ) assays_to_upsert.append(assay_upsert) diff --git a/scripts/parse_existing_cohort.py b/scripts/parse_existing_cohort.py index 34d64f370..2603a8d60 100644 --- a/scripts/parse_existing_cohort.py +++ b/scripts/parse_existing_cohort.py @@ -111,6 +111,7 @@ def __init__( batch_number, include_participant_column, allow_missing_files, + sequencing_type, ): if include_participant_column: participant_column = Columns.PARTICIPANT_COLUMN @@ -131,6 +132,7 @@ def __init__( assay_meta_map=Columns.sequence_meta_map(), batch_number=batch_number, allow_extra_files_in_search_path=True, + default_sequencing_type=sequencing_type, ) def _get_dict_reader(self, file_pointer, delimiter: str): @@ -210,6 +212,11 @@ def 
get_existing_external_sequence_ids(self, participant_map: dict[str, dict]): '--project', help='The metamist project to import manifest into', ) +@click.option( + '--sequencing-type', + type=click.Choice(['genome', 'exome']), + help='Sequencing type: genome or exome', +) @click.option('--search-location', 'search_locations', multiple=True) @click.option( '--confirm', is_flag=True, help='Confirm with user input before updating server' @@ -236,6 +243,7 @@ async def main( dry_run=False, include_participant_column=False, allow_missing_files=False, + sequencing_type: str = 'genome', ): """Run script from CLI arguments""" @@ -245,6 +253,7 @@ async def main( batch_number=batch_number, include_participant_column=include_participant_column, allow_missing_files=allow_missing_files, + sequencing_type=sequencing_type, ) for manifest_path in manifests: diff --git a/scripts/parse_sample_file_map.py b/scripts/parse_sample_file_map.py index df139d7d1..ccc5059a3 100755 --- a/scripts/parse_sample_file_map.py +++ b/scripts/parse_sample_file_map.py @@ -35,7 +35,7 @@ help='The metamist project to import manifest into', ) @click.option('--default-sample-type', default='blood') -@click.option('--default-sequence-type', default='wgs') +@click.option('--default-sequencing-type', default='wgs') @click.option('--default-sequence-technology', default='short-read') @click.option( '--confirm', is_flag=True, help='Confirm with user input before updating server' diff --git a/setup.py b/setup.py index fc47d5a14..0b812b931 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ setup( name=PKG, # This tag is automatically updated by bump2version - version='6.5.0', + version='6.6.2', description='Python API for interacting with the Sample API system', long_description=readme, long_description_content_type='text/markdown', diff --git a/test/data/generate_data.py b/test/data/generate_data.py index 329708098..b04f270bc 100755 --- a/test/data/generate_data.py +++ b/test/data/generate_data.py @@ -49,6 +49,7 @@ async def main(ped_path=default_ped_location, project='greek-myth'): papi = ProjectApi() sapi = SampleApi() + aapi = AnalysisApi() enum_resp: dict[str, dict[str, list[str]]] = await query_async(QUERY_ENUMS) # analysis_types = enum_resp['enum']['analysisType'] @@ -221,11 +222,42 @@ def generate_random_number_within_distribution(): ) ) - aapi = AnalysisApi() for ans in chunk(analyses_to_insert, 50): print(f'Inserting {len(ans)} analysis entries') await asyncio.gather(*[aapi.create_analysis_async(project, a) for a in ans]) + # create some fake analysis-runner entries + ar_entries = 20 + print(f'Inserting {ar_entries} analysis-runner entries') + await asyncio.gather( + *( + aapi.create_analysis_async( + project, + Analysis( + sequencing_group_ids=[], + type='analysis-runner', + status=AnalysisStatus('unknown'), + output='gs://cpg-fake-bucket/output', + meta={ + 'timestamp': f'2022-08-{i+1}T10:00:00.0000+00:00', + 'accessLevel': 'standard', + 'repo': 'sample-metadata', + 'commit': '7234c13855cc15b3471d340757ce87e7441abeb9', + 'script': 'python3 -m
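# Illustrative sketch only (not part of this patch): exercising the updated
# metamist.graphql.query() helper introduced earlier in this change. The
# GraphQL document below is a made-up example; the new `log_response` keyword
# argument is the behaviour being demonstrated.
from metamist.graphql import gql, query

_EXAMPLE_QUERY = gql(
    """
    query ExampleProjects {
        myProjects {
            name
        }
    }
    """
)

# log_response defaults to False, which keeps gql's transport logging quiet
# while the request runs; pass log_response=True to see the raw response in
# the logs.
result = query(_EXAMPLE_QUERY, log_response=False)
print(result)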