diff --git a/api/graphql/loaders.py b/api/graphql/loaders.py
index de3606668..c85699e4a 100644
--- a/api/graphql/loaders.py
+++ b/api/graphql/loaders.py
@@ -47,6 +47,7 @@ class LoaderKeys(enum.Enum):
     PROJECTS_FOR_IDS = 'projects_for_id'
 
     AUDIT_LOGS_BY_IDS = 'audit_logs_by_ids'
+    AUDIT_LOGS_BY_ANALYSIS_IDS = 'audit_logs_by_analysis_ids'
 
     ANALYSES_FOR_SEQUENCING_GROUPS = 'analyses_for_sequencing_groups'
@@ -185,6 +186,18 @@ async def load_audit_logs_by_ids(
     return [logs_by_id.get(a) for a in audit_log_ids]
 
 
+@connected_data_loader(LoaderKeys.AUDIT_LOGS_BY_ANALYSIS_IDS)
+async def load_audit_logs_by_analysis_ids(
+    analysis_ids: list[int], connection
+) -> list[list[AuditLogInternal]]:
+    """
+    DataLoader: get_audit_logs_by_analysis_ids
+    """
+    alayer = AnalysisLayer(connection)
+    logs = await alayer.get_audit_logs_by_analysis_ids(analysis_ids)
+    return [logs.get(a) or [] for a in analysis_ids]
+
+
 @connected_data_loader_with_params(LoaderKeys.ASSAYS_FOR_SAMPLES, default_factory=list)
 async def load_assays_by_samples(
     connection, ids, filter: AssayFilter
diff --git a/api/graphql/schema.py b/api/graphql/schema.py
index 2bb445ecf..d0ac1619f 100644
--- a/api/graphql/schema.py
+++ b/api/graphql/schema.py
@@ -214,6 +214,7 @@ class GraphQLAuditLog:
     timestamp: datetime.datetime
     ar_guid: str | None
     comment: str | None
+    meta: strawberry.scalars.JSON
 
     @staticmethod
     def from_internal(audit_log: AuditLogInternal) -> 'GraphQLAuditLog':
@@ -223,6 +224,7 @@ def from_internal(audit_log: AuditLogInternal) -> 'GraphQLAuditLog':
             timestamp=audit_log.timestamp,
             ar_guid=audit_log.ar_guid,
             comment=audit_log.comment,
+            meta=audit_log.meta,
         )
 
 
@@ -237,7 +239,6 @@ class GraphQLAnalysis:
     timestamp_completed: datetime.datetime | None = None
     active: bool
     meta: strawberry.scalars.JSON
-    author: str
 
     @staticmethod
     def from_internal(internal: AnalysisInternal) -> 'GraphQLAnalysis':
@@ -249,7 +250,6 @@ def from_internal(internal: AnalysisInternal) -> 'GraphQLAnalysis':
             timestamp_completed=internal.timestamp_completed,
             active=internal.active,
             meta=internal.meta,
-            author=internal.author,
         )
 
     @strawberry.field
@@ -266,6 +266,14 @@ async def project(self, info: Info, root: 'GraphQLAnalysis') -> GraphQLProject:
         project = await loader.load(root.project)
         return GraphQLProject.from_internal(project)
 
+    @strawberry.field
+    async def audit_logs(
+        self, info: Info, root: 'GraphQLAnalysis'
+    ) -> list[GraphQLAuditLog]:
+        loader = info.context[LoaderKeys.AUDIT_LOGS_BY_ANALYSIS_IDS]
+        audit_logs = await loader.load(root.id)
+        return [GraphQLAuditLog.from_internal(audit_log) for audit_log in audit_logs]
+
 
 @strawberry.type
 class GraphQLFamily:
@@ -393,7 +401,6 @@ class GraphQLSample:
     active: bool
     meta: strawberry.scalars.JSON
     type: str
-    author: str | None
 
     # keep as integers, because they're useful to reference in the fields below
     internal_id: strawberry.Private[int]
@@ -408,7 +415,6 @@ def from_internal(sample: SampleInternal) -> 'GraphQLSample':
             active=sample.active,
             meta=sample.meta,
             type=sample.type,
-            author=sample.author,
             # internals
             internal_id=sample.id,
             participant_id=sample.participant_id,
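The schema change above wires the new `audit_logs` field through the `AUDIT_LOGS_BY_ANALYSIS_IDS` loader, so a nested query batches all analysis IDs into a single database call rather than one per analysis. A minimal client-side sketch of the new field — the `gql`/`query` helpers come from the metamist Python client, and the surrounding `project`/`analyses` query shape is an illustrative assumption, not part of this diff:

```python
# Hypothetical query against the new field; strawberry exposes the
# snake_case `audit_logs` resolver as camelCase `auditLogs`.
from metamist.graphql import gql, query

_AUDIT_LOG_QUERY = gql(
    """
    query AnalysisAuditLogs($project: String!) {
        project(name: $project) {
            analyses {
                id
                auditLogs {
                    author
                    timestamp
                    meta
                }
            }
        }
    }
    """
)

result = query(_AUDIT_LOG_QUERY, variables={'project': 'greek-myth'})
```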
diff --git a/api/server.py b/api/server.py
index 3460ec285..652aa4322 100644
--- a/api/server.py
+++ b/api/server.py
@@ -2,22 +2,21 @@
 import time
 import traceback
 
-from fastapi import FastAPI, Request, HTTPException, APIRouter
+from fastapi import APIRouter, FastAPI, HTTPException, Request
+from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
 from fastapi.staticfiles import StaticFiles
-from fastapi.middleware.cors import CORSMiddleware
 from pydantic import ValidationError
 from starlette.responses import FileResponse
 
-from db.python.connect import SMConnections
-from db.python.tables.project import is_all_access
-from db.python.utils import get_logger
-
 from api import routes
-from api.utils import get_openapi_schema_func
-from api.utils.exceptions import determine_code_from_error
 from api.graphql.schema import MetamistGraphQLRouter  # type: ignore
 from api.settings import PROFILE_REQUESTS, SKIP_DATABASE_CONNECTION
+from api.utils import get_openapi_schema_func
+from api.utils.exceptions import determine_code_from_error
+from db.python.connect import SMConnections
+from db.python.tables.project import is_all_access
+from db.python.utils import get_logger
 
 # This tag is automatically updated by bump2version
 _VERSION = '6.6.2'
@@ -141,11 +140,11 @@ async def exception_handler(request: Request, e: Exception):
         cors_middleware = middlewares[0]
 
         request_origin = request.headers.get('origin', '')
-        if cors_middleware and '*' in cors_middleware.options['allow_origins']:
+        if cors_middleware and '*' in cors_middleware.options['allow_origins']:  # type: ignore
             response.headers['Access-Control-Allow-Origin'] = '*'
         elif (
             cors_middleware
-            and request_origin in cors_middleware.options['allow_origins']
+            and request_origin in cors_middleware.options['allow_origins']  # type: ignore
         ):
             response.headers['Access-Control-Allow-Origin'] = request_origin
 
@@ -169,9 +168,10 @@ async def exception_handler(request: Request, e: Exception):
 
 
 if __name__ == '__main__':
-    import uvicorn
     import logging
 
+    import uvicorn
+
     logging.getLogger('watchfiles').setLevel(logging.WARNING)
     logging.getLogger('watchfiles.main').setLevel(logging.WARNING)
diff --git a/db/python/layers/analysis.py b/db/python/layers/analysis.py
index 67fd18c13..f54c91eaa 100644
--- a/db/python/layers/analysis.py
+++ b/db/python/layers/analysis.py
@@ -13,6 +13,7 @@
 from models.enums import AnalysisStatus
 from models.models import (
     AnalysisInternal,
+    AuditLogInternal,
     ProportionalDateModel,
     ProportionalDateProjectModel,
     ProportionalDateTemporalMethod,
@@ -530,6 +531,10 @@ async def get_sgs_added_by_day_by_es_indices(
 
         return by_day
 
+    async def get_audit_logs_by_analysis_ids(self, analysis_ids: list[int]) -> dict[int, list[AuditLogInternal]]:
+        """Get audit logs for analysis IDs"""
+        return await self.at.get_audit_log_for_analysis_ids(analysis_ids)
+
     # CREATE / UPDATE
 
     async def create_analysis(
diff --git a/db/python/tables/analysis.py b/db/python/tables/analysis.py
index d5c5d5fb7..ee445345e 100644
--- a/db/python/tables/analysis.py
+++ b/db/python/tables/analysis.py
@@ -14,6 +14,7 @@
 )
 from models.enums import AnalysisStatus
 from models.models.analysis import AnalysisInternal
+from models.models.audit_log import AuditLogInternal
 from models.models.project import ProjectId
 
 
@@ -395,7 +396,7 @@ async def get_analyses_for_samples(
         self,
         sample_ids: list[int],
         analysis_type: str | None,
-        status: AnalysisStatus,
+        status: AnalysisStatus | None,
     ) -> tuple[set[ProjectId], list[AnalysisInternal]]:
         """
         Get relevant analyses for a sample, optional type / status filters
@@ -593,3 +594,11 @@ async def get_sg_add_to_project_es_index(
         rows = await self.connection.fetch_all(_query, {'sg_ids': sg_ids})
 
         return {r['sg_id']: r['timestamp_completed'].date() for r in rows}
+
+    async def get_audit_log_for_analysis_ids(
+        self, analysis_ids: list[int]
+    ) -> dict[int, list[AuditLogInternal]]:
+        """
+        Get audit logs for analysis IDs
+        """
+        return await self.get_all_audit_logs_for_table('analysis', analysis_ids)
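Note the contract being bridged here: the table and layer methods return a dict keyed by analysis ID (IDs with no audit entries are simply absent), while the DataLoader protocol requires one result per requested key, in request order — hence the `logs.get(a) or []` fallback in the loaders.py change above. A toy sketch of that mapping step, with invented data:

```python
# Invented data illustrating why the loader falls back to an empty list.
analysis_ids = [101, 102, 103]

# Shape returned by get_audit_log_for_analysis_ids: a dict keyed by
# analysis ID; 102 has no audit entries, so it has no key at all.
logs_by_analysis = {101: ['log-a', 'log-b'], 103: ['log-c']}

# DataLoader contract: positional results, one per requested key.
batched = [logs_by_analysis.get(a) or [] for a in analysis_ids]
assert batched == [['log-a', 'log-b'], [], ['log-c']]
```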
diff --git a/db/python/tables/base.py b/db/python/tables/base.py
index 82815d10b..e29353de8 100644
--- a/db/python/tables/base.py
+++ b/db/python/tables/base.py
@@ -1,7 +1,10 @@
+from collections import defaultdict
+
 import databases
 
 from db.python.connect import Connection
 from db.python.utils import InternalError
+from models.models.audit_log import AuditLogInternal
 
 
 class DbBase:
@@ -40,3 +43,34 @@ def escape_like_term(query: str):
         Escape meaningful keys when using LIKE with a user supplied input
         """
         return query.replace('%', '\\%').replace('_', '\\_')
+
+    async def get_all_audit_logs_for_table(
+        self, table: str, ids: list[int], id_field='id'
+    ) -> dict[int, list[AuditLogInternal]]:
+        """
+        Get all audit logs for values from a table
+        """
+        _query = f"""
+            SELECT
+                t.{id_field} as table_id,
+                al.id as id,
+                al.author as author,
+                al.on_behalf_of as on_behalf_of,
+                al.timestamp as timestamp,
+                al.ar_guid as ar_guid,
+                al.comment as comment,
+                al.auth_project as auth_project,
+                al.meta as meta
+            FROM {table} FOR SYSTEM_TIME ALL t
+            INNER JOIN audit_log al
+            ON al.id = t.audit_log_id
+            WHERE t.{id_field} in :ids
+        """.strip()
+        rows = await self.connection.fetch_all(_query, {'ids': ids})
+        by_id = defaultdict(list)
+        for r in rows:
+            row = dict(r)
+            id_value = row.pop('table_id')
+            by_id[id_value].append(AuditLogInternal.from_db(row))
+
+        return by_id
diff --git a/models/models/audit_log.py b/models/models/audit_log.py
index 3a2f2e69c..20969dd50 100644
--- a/models/models/audit_log.py
+++ b/models/models/audit_log.py
@@ -1,4 +1,5 @@
 import datetime
+import json
 
 from models.base import SMBase
 from models.models.project import ProjectId
@@ -18,8 +19,13 @@ class AuditLogInternal(SMBase):
     on_behalf_of: str | None
     ar_guid: str | None
     comment: str | None
+    meta: dict | None
 
     @staticmethod
     def from_db(d: dict):
         """Take DB mapping object, and return SampleSequencing"""
-        return AuditLogInternal(**d)
+        meta = {}
+        if 'meta' in d:
+            meta = json.loads(d.pop('meta'))
+
+        return AuditLogInternal(meta=meta, **d)
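The `from_db` change assumes the driver hands the `meta` column back as a JSON-encoded string (which is how the `SELECT al.meta as meta` in base.py above arrives from MariaDB), so it is decoded before the model is built. A small round-trip sketch with invented row values — it presumes `AuditLogInternal` also carries an `auth_project` field, matching the columns selected above:

```python
import datetime

from models.models.audit_log import AuditLogInternal  # as changed in this diff

# A row shaped like the audit_log SELECT in base.py; values are invented.
# Note `meta` arrives as a JSON string, not a dict.
row = {
    'id': 1,
    'author': 'user@example.org',
    'on_behalf_of': None,
    'timestamp': datetime.datetime(2023, 11, 1, 10, 0, 0),
    'ar_guid': None,
    'comment': None,
    'auth_project': 1,
    'meta': '{"path": "/api/v1/analysis/"}',
}

log = AuditLogInternal.from_db(row)  # pops 'meta' and json.loads() it
assert log.meta == {'path': '/api/v1/analysis/'}
```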
diff --git a/scripts/create_md5s.py b/scripts/create_md5s.py
index 488ec8297..f99f64906 100644
--- a/scripts/create_md5s.py
+++ b/scripts/create_md5s.py
@@ -12,24 +12,25 @@ def create_md5s_for_files_in_directory(skip_filetypes: tuple[str, str], force_re
     if not gs_dir.startswith('gs://'):
         raise ValueError(f'Expected GS directory, got: {gs_dir}')
 
-    billing_project = get_config()['hail']['billing_project']
+    billing_project = get_config()['workflow']['gcp_billing_project']
     driver_image = get_config()['workflow']['driver_image']
 
     bucket_name, *components = gs_dir[5:].split('/')
 
     client = storage.Client()
-    blobs = client.list_blobs(bucket_name, prefix='/'.join(components))
+    bucket = client.bucket(bucket_name, user_project=billing_project)
+    blobs = bucket.list_blobs(prefix='/'.join(components))
     files: set[str] = {f'gs://{bucket_name}/{blob.name}' for blob in blobs}
 
-    for obj in files:
-        if obj.endswith('.md5') or obj.endswith(skip_filetypes):
+    for filepath in files:
+        if filepath.endswith('.md5') or filepath.endswith(skip_filetypes):
             continue
-        if f'{obj}.md5' in files and not force_recreate:
-            print(f'{obj}.md5 already exists, skipping')
+        if f'{filepath}.md5' in files and not force_recreate:
+            print(f'{filepath}.md5 already exists, skipping')
             continue
 
-        print('Creating md5 for', obj)
-        job = b.new_job(f'Create {os.path.basename(obj)}.md5')
-        create_md5(job, obj, billing_project, driver_image)
+        print('Creating md5 for', filepath)
+        job = b.new_job(f'Create {os.path.basename(filepath)}.md5')
+        create_md5(job, filepath, billing_project, driver_image)
 
     b.run(wait=False)
@@ -46,7 +47,7 @@ def create_md5(job, file, billing_project, driver_image):
         f"""\
         set -euxo pipefail
         gcloud -q auth activate-service-account --key-file=$GOOGLE_APPLICATION_CREDENTIALS
-        gsutil cat {file} | md5sum | cut -d " " -f1 > /tmp/uploaded.md5
+        gsutil -u {billing_project} cat {file} | md5sum | cut -d " " -f1 > /tmp/uploaded.md5
         gsutil -u {billing_project} cp /tmp/uploaded.md5 {md5}
         """
     )
diff --git a/test/data/generate_data.py b/test/data/generate_data.py
index 329708098..b04f270bc 100755
--- a/test/data/generate_data.py
+++ b/test/data/generate_data.py
@@ -49,6 +49,7 @@ async def main(ped_path=default_ped_location, project='greek-myth'):
     papi = ProjectApi()
     sapi = SampleApi()
+    aapi = AnalysisApi()
 
     enum_resp: dict[str, dict[str, list[str]]] = await query_async(QUERY_ENUMS)
     # analysis_types = enum_resp['enum']['analysisType']
@@ -221,11 +222,42 @@ def generate_random_number_within_distribution():
             )
         )
 
-    aapi = AnalysisApi()
     for ans in chunk(analyses_to_insert, 50):
         print(f'Inserting {len(ans)} analysis entries')
         await asyncio.gather(*[aapi.create_analysis_async(project, a) for a in ans])
 
+    # create some fake analysis-runner entries
+    ar_entries = 20
+    print(f'Inserting {ar_entries} analysis-runner entries')
+    await asyncio.gather(
+        *(
+            aapi.create_analysis_async(
+                project,
+                Analysis(
+                    sequencing_group_ids=[],
+                    type='analysis-runner',
+                    status=AnalysisStatus('unknown'),
+                    output='gs://cpg-fake-bucket/output',
+                    meta={
+                        'timestamp': f'2022-08-{i+1}T10:00:00.0000+00:00',
+                        'accessLevel': 'standard',
+                        'repo': 'sample-metadata',
+                        'commit': '7234c13855cc15b3471d340757ce87e7441abeb9',
+                        'script': 'python3 -m
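A closing note on the `scripts/create_md5s.py` changes above: both the blob listing and the `gsutil` calls now pass the billing project explicitly, which is what requester-pays buckets require — without it, list and read operations fail with a 400. The same pattern in isolation, using the real `google-cloud-storage` API but invented bucket and project names:

```python
from google.cloud import storage

# `user_project` does for the Python client what `gsutil -u` does on the
# command line: it names the project billed for requester-pays access.
client = storage.Client()
bucket = client.bucket('cpg-example-bucket', user_project='example-billing-project')
for blob in bucket.list_blobs(prefix='exome/batch1/'):
    print(blob.name)
```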