Merge branch 'dev' into populate-tob-test-with-nagim
michael-harper committed Jan 17, 2024
2 parents 6c07839 + fcf07f2 commit a3839ad
Showing 12 changed files with 157 additions and 45 deletions.
13 changes: 13 additions & 0 deletions api/graphql/loaders.py
@@ -47,6 +47,7 @@ class LoaderKeys(enum.Enum):
     PROJECTS_FOR_IDS = 'projects_for_id'

     AUDIT_LOGS_BY_IDS = 'audit_logs_by_ids'
+    AUDIT_LOGS_BY_ANALYSIS_IDS = 'audit_logs_by_analysis_ids'

     ANALYSES_FOR_SEQUENCING_GROUPS = 'analyses_for_sequencing_groups'

@@ -185,6 +186,18 @@ async def load_audit_logs_by_ids(
     return [logs_by_id.get(a) for a in audit_log_ids]


+@connected_data_loader(LoaderKeys.AUDIT_LOGS_BY_ANALYSIS_IDS)
+async def load_audit_logs_by_analysis_ids(
+    analysis_ids: list[int], connection
+) -> list[list[AuditLogInternal]]:
+    """
+    DataLoader: get_audit_logs_by_analysis_ids
+    """
+    alayer = AnalysisLayer(connection)
+    logs = await alayer.get_audit_logs_by_analysis_ids(analysis_ids)
+    return [logs.get(a) or [] for a in analysis_ids]
+
+
 @connected_data_loader_with_params(LoaderKeys.ASSAYS_FOR_SAMPLES, default_factory=list)
 async def load_assays_by_samples(
     connection, ids, filter: AssayFilter
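For context on the dataloader contract: the function receives a batch of keys and must return exactly one result per key, in request order, so `logs.get(a) or []` keeps the alignment and substitutes an empty list for analyses with no audit logs. A minimal sketch of that ordering behaviour, with toy data standing in for the layer's result:

    # Toy stand-in for the dict[analysis_id, list[AuditLogInternal]] the layer returns
    logs = {1: ['log-a'], 3: ['log-b', 'log-c']}
    analysis_ids = [3, 1, 2]

    # One result per requested key, in request order; [] where nothing matched
    result = [logs.get(a) or [] for a in analysis_ids]
    assert result == [['log-b', 'log-c'], ['log-a'], []]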
14 changes: 10 additions & 4 deletions api/graphql/schema.py
@@ -214,6 +214,7 @@ class GraphQLAuditLog:
     timestamp: datetime.datetime
     ar_guid: str | None
     comment: str | None
+    meta: strawberry.scalars.JSON

     @staticmethod
     def from_internal(audit_log: AuditLogInternal) -> 'GraphQLAuditLog':
@@ -223,6 +224,7 @@ def from_internal(audit_log: AuditLogInternal) -> 'GraphQLAuditLog':
             timestamp=audit_log.timestamp,
             ar_guid=audit_log.ar_guid,
             comment=audit_log.comment,
+            meta=audit_log.meta,
         )

@@ -237,7 +239,6 @@ class GraphQLAnalysis:
     timestamp_completed: datetime.datetime | None = None
     active: bool
     meta: strawberry.scalars.JSON
-    author: str

     @staticmethod
     def from_internal(internal: AnalysisInternal) -> 'GraphQLAnalysis':
@@ -249,7 +250,6 @@ def from_internal(internal: AnalysisInternal) -> 'GraphQLAnalysis':
             timestamp_completed=internal.timestamp_completed,
             active=internal.active,
             meta=internal.meta,
-            author=internal.author,
         )

     @strawberry.field
@@ -266,6 +266,14 @@ async def project(self, info: Info, root: 'GraphQLAnalysis') -> GraphQLProject:
         project = await loader.load(root.project)
         return GraphQLProject.from_internal(project)

+    @strawberry.field
+    async def audit_logs(
+        self, info: Info, root: 'GraphQLAnalysis'
+    ) -> list[GraphQLAuditLog]:
+        loader = info.context[LoaderKeys.AUDIT_LOGS_BY_ANALYSIS_IDS]
+        audit_logs = await loader.load(root.id)
+        return [GraphQLAuditLog.from_internal(audit_log) for audit_log in audit_logs]
+

 @strawberry.type
 class GraphQLFamily:
@@ -393,7 +401,6 @@ class GraphQLSample:
     active: bool
     meta: strawberry.scalars.JSON
     type: str
-    author: str | None

     # keep as integers, because they're useful to reference in the fields below
     internal_id: strawberry.Private[int]
@@ -408,7 +415,6 @@ def from_internal(sample: SampleInternal) -> 'GraphQLSample':
             active=sample.active,
             meta=sample.meta,
             type=sample.type,
-            author=sample.author,
             # internals
             internal_id=sample.id,
             participant_id=sample.participant_id,
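Taken together, the schema changes drop the deprecated `author` field from analyses and samples, and expose audit logs (now carrying `meta`) on each analysis. A hypothetical client query sketch — only the `auditLogs` selection is grounded in this diff; the surrounding `project`/`analyses` shape and the `metamist.graphql` helper usage are assumptions, and strawberry exposes `audit_logs`/`ar_guid` as camelCase:

    from metamist.graphql import gql, query

    # Hypothetical query shape; auditLogs { timestamp arGuid meta } maps to the
    # fields added in this commit.
    _QUERY = gql(
        """
        query AnalysisAuditLogs($project: String!) {
          project(name: $project) {
            analyses {
              id
              auditLogs {
                timestamp
                arGuid
                meta
              }
            }
          }
        }
        """
    )
    print(query(_QUERY, variables={'project': 'greek-myth'}))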
22 changes: 11 additions & 11 deletions api/server.py
@@ -2,22 +2,21 @@
 import time
 import traceback

-from fastapi import FastAPI, Request, HTTPException, APIRouter
+from fastapi import APIRouter, FastAPI, HTTPException, Request
+from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
 from fastapi.staticfiles import StaticFiles
-from fastapi.middleware.cors import CORSMiddleware
 from pydantic import ValidationError
 from starlette.responses import FileResponse

-from db.python.connect import SMConnections
-from db.python.tables.project import is_all_access
-from db.python.utils import get_logger
-
 from api import routes
-from api.utils import get_openapi_schema_func
-from api.utils.exceptions import determine_code_from_error
 from api.graphql.schema import MetamistGraphQLRouter  # type: ignore
 from api.settings import PROFILE_REQUESTS, SKIP_DATABASE_CONNECTION
+from api.utils import get_openapi_schema_func
+from api.utils.exceptions import determine_code_from_error
+from db.python.connect import SMConnections
+from db.python.tables.project import is_all_access
+from db.python.utils import get_logger

 # This tag is automatically updated by bump2version
 _VERSION = '6.6.2'
@@ -141,11 +140,11 @@ async def exception_handler(request: Request, e: Exception):
     cors_middleware = middlewares[0]

     request_origin = request.headers.get('origin', '')
-    if cors_middleware and '*' in cors_middleware.options['allow_origins']:
+    if cors_middleware and '*' in cors_middleware.options['allow_origins']:  # type: ignore
         response.headers['Access-Control-Allow-Origin'] = '*'
     elif (
         cors_middleware
-        and request_origin in cors_middleware.options['allow_origins']
+        and request_origin in cors_middleware.options['allow_origins']  # type: ignore
     ):
         response.headers['Access-Control-Allow-Origin'] = request_origin

@@ -169,9 +168,10 @@ async def exception_handler(request: Request, e: Exception):


 if __name__ == '__main__':
-    import uvicorn
     import logging

+    import uvicorn
+
     logging.getLogger('watchfiles').setLevel(logging.WARNING)
     logging.getLogger('watchfiles.main').setLevel(logging.WARNING)

5 changes: 5 additions & 0 deletions db/python/layers/analysis.py
@@ -13,6 +13,7 @@
 from models.enums import AnalysisStatus
 from models.models import (
     AnalysisInternal,
+    AuditLogInternal,
     ProportionalDateModel,
     ProportionalDateProjectModel,
     ProportionalDateTemporalMethod,
@@ -530,6 +531,10 @@ async def get_sgs_added_by_day_by_es_indices(

         return by_day

+    async def get_audit_logs_by_analysis_ids(self, analysis_ids: list[int]) -> dict[int, list[AuditLogInternal]]:
+        """Get audit logs for analysis IDs"""
+        return await self.at.get_audit_log_for_analysis_ids(analysis_ids)
+
     # CREATE / UPDATE

     async def create_analysis(
11 changes: 10 additions & 1 deletion db/python/tables/analysis.py
@@ -14,6 +14,7 @@
 )
 from models.enums import AnalysisStatus
 from models.models.analysis import AnalysisInternal
+from models.models.audit_log import AuditLogInternal
 from models.models.project import ProjectId


@@ -395,7 +396,7 @@ async def get_analyses_for_samples(
         self,
         sample_ids: list[int],
         analysis_type: str | None,
-        status: AnalysisStatus,
+        status: AnalysisStatus | None,
     ) -> tuple[set[ProjectId], list[AnalysisInternal]]:
         """
         Get relevant analyses for a sample, optional type / status filters
@@ -593,3 +594,11 @@ async def get_sg_add_to_project_es_index(

         rows = await self.connection.fetch_all(_query, {'sg_ids': sg_ids})
         return {r['sg_id']: r['timestamp_completed'].date() for r in rows}
+
+    async def get_audit_log_for_analysis_ids(
+        self, analysis_ids: list[int]
+    ) -> dict[int, list[AuditLogInternal]]:
+        """
+        Get audit logs for analysis IDs
+        """
+        return await self.get_all_audit_logs_for_table('analysis', analysis_ids)
34 changes: 34 additions & 0 deletions db/python/tables/base.py
@@ -1,7 +1,10 @@
+from collections import defaultdict
+
 import databases

 from db.python.connect import Connection
 from db.python.utils import InternalError
+from models.models.audit_log import AuditLogInternal


 class DbBase:
@@ -40,3 +43,34 @@ def escape_like_term(query: str):
         Escape meaningful keys when using LIKE with a user supplied input
         """
         return query.replace('%', '\\%').replace('_', '\\_')
+
+    async def get_all_audit_logs_for_table(
+        self, table: str, ids: list[int], id_field='id'
+    ) -> dict[int, list[AuditLogInternal]]:
+        """
+        Get all audit logs for values from a table
+        """
+        _query = f"""
+        SELECT
+            t.{id_field} as table_id,
+            al.id as id,
+            al.author as author,
+            al.on_behalf_of as on_behalf_of,
+            al.timestamp as timestamp,
+            al.ar_guid as ar_guid,
+            al.comment as comment,
+            al.auth_project as auth_project,
+            al.meta as meta
+        FROM {table} FOR SYSTEM_TIME ALL t
+        INNER JOIN audit_log al
+        ON al.id = t.audit_log_id
+        WHERE t.{id_field} in :ids
+        """.strip()
+        rows = await self.connection.fetch_all(_query, {'ids': ids})
+        by_id = defaultdict(list)
+        for r in rows:
+            row = dict(r)
+            id_value = row.pop('table_id')
+            by_id[id_value].append(AuditLogInternal.from_db(row))
+
+        return by_id
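`FOR SYSTEM_TIME ALL` is MariaDB's system-versioned-table syntax: the join matches every historical version of each row, so audit logs are recovered even for row versions whose `audit_log_id` has since been superseded. A hypothetical caller sketch — any table class inheriting `DbBase` gets this helper; the concrete class name below is assumed:

    from db.python.tables.analysis import AnalysisTable  # assumed class name

    async def audit_trail(connection, analysis_ids: list[int]):
        # Returns {analysis_id: [AuditLogInternal, ...]}; IDs with no logs are absent
        at = AnalysisTable(connection)
        return await at.get_all_audit_logs_for_table('analysis', analysis_ids)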
8 changes: 7 additions & 1 deletion models/models/audit_log.py
@@ -1,4 +1,5 @@
 import datetime
+import json

 from models.base import SMBase
 from models.models.project import ProjectId
@@ -18,8 +19,13 @@ class AuditLogInternal(SMBase):
     on_behalf_of: str | None
     ar_guid: str | None
     comment: str | None
+    meta: dict | None

     @staticmethod
     def from_db(d: dict):
         """Take DB mapping object, and return SampleSequencing"""
-        return AuditLogInternal(**d)
+        meta = {}
+        if 'meta' in d:
+            meta = json.loads(d.pop('meta'))
+
+        return AuditLogInternal(meta=meta, **d)
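The `meta` column arrives from the driver as a JSON-encoded string, so `from_db` pops it and decodes it before the `**d` expansion (which would otherwise pass `meta` twice). A self-contained sketch of that behaviour; the row values are invented, shaped like the SELECT in `get_all_audit_logs_for_table`:

    import datetime

    from models.models.audit_log import AuditLogInternal

    row = {
        'id': 42,
        'author': 'user@example.org',
        'on_behalf_of': None,
        'timestamp': datetime.datetime(2024, 1, 17, 10, 0),
        'ar_guid': None,
        'comment': 'analysis status updated',
        'auth_project': 1,
        'meta': '{"status": "completed"}',  # JSON string straight from the DB
    }

    audit_log = AuditLogInternal.from_db(row)
    assert audit_log.meta == {'status': 'completed'}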
21 changes: 11 additions & 10 deletions scripts/create_md5s.py
@@ -12,24 +12,25 @@ def create_md5s_for_files_in_directory(skip_filetypes: tuple[str, str], force_re
     if not gs_dir.startswith('gs://'):
         raise ValueError(f'Expected GS directory, got: {gs_dir}')

-    billing_project = get_config()['hail']['billing_project']
+    billing_project = get_config()['workflow']['gcp_billing_project']
     driver_image = get_config()['workflow']['driver_image']

     bucket_name, *components = gs_dir[5:].split('/')

     client = storage.Client()
-    blobs = client.list_blobs(bucket_name, prefix='/'.join(components))
+    bucket = client.bucket(bucket_name, user_project=billing_project)
+    blobs = bucket.list_blobs(prefix='/'.join(components))
     files: set[str] = {f'gs://{bucket_name}/{blob.name}' for blob in blobs}
-    for obj in files:
-        if obj.endswith('.md5') or obj.endswith(skip_filetypes):
+    for filepath in files:
+        if filepath.endswith('.md5') or filepath.endswith(skip_filetypes):
             continue
-        if f'{obj}.md5' in files and not force_recreate:
-            print(f'{obj}.md5 already exists, skipping')
+        if f'{filepath}.md5' in files and not force_recreate:
+            print(f'{filepath}.md5 already exists, skipping')
             continue

-        print('Creating md5 for', obj)
-        job = b.new_job(f'Create {os.path.basename(obj)}.md5')
-        create_md5(job, obj, billing_project, driver_image)
+        print('Creating md5 for', filepath)
+        job = b.new_job(f'Create {os.path.basename(filepath)}.md5')
+        create_md5(job, filepath, billing_project, driver_image)

     b.run(wait=False)

@@ -46,7 +47,7 @@ def create_md5(job, file, billing_project, driver_image):
         f"""\
         set -euxo pipefail
         gcloud -q auth activate-service-account --key-file=$GOOGLE_APPLICATION_CREDENTIALS
-        gsutil cat {file} | md5sum | cut -d " " -f1 > /tmp/uploaded.md5
+        gsutil -u {billing_project} cat {file} | md5sum | cut -d " " -f1 > /tmp/uploaded.md5
         gsutil -u {billing_project} cp /tmp/uploaded.md5 {md5}
         """
     )
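Both create_md5s.py changes are about requester-pays buckets: passing `user_project=` on the bucket handle bills list operations to the nominated project, and `gsutil -u` does the same for reads inside the batch job. A minimal sketch of the client-side pattern (bucket and project names are placeholders):

    from google.cloud import storage

    client = storage.Client()
    # user_project nominates the billing project for requester-pays access
    bucket = client.bucket('example-bucket', user_project='example-billing-project')
    for blob in bucket.list_blobs(prefix='exomes/'):
        print(blob.name)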
34 changes: 33 additions & 1 deletion test/data/generate_data.py
@@ -49,6 +49,7 @@ async def main(ped_path=default_ped_location, project='greek-myth'):

     papi = ProjectApi()
     sapi = SampleApi()
+    aapi = AnalysisApi()

     enum_resp: dict[str, dict[str, list[str]]] = await query_async(QUERY_ENUMS)
     # analysis_types = enum_resp['enum']['analysisType']
@@ -221,11 +222,42 @@ def generate_random_number_within_distribution():
             )
         )

-    aapi = AnalysisApi()
     for ans in chunk(analyses_to_insert, 50):
         print(f'Inserting {len(ans)} analysis entries')
         await asyncio.gather(*[aapi.create_analysis_async(project, a) for a in ans])

+    # create some fake analysis-runner entries
+    ar_entries = 20
+    print(f'Inserting {ar_entries} analysis-runner entries')
+    await asyncio.gather(
+        *(
+            aapi.create_analysis_async(
+                project,
+                Analysis(
+                    sequencing_group_ids=[],
+                    type='analysis-runner',
+                    status=AnalysisStatus('unknown'),
+                    output='gs://cpg-fake-bucket/output',
+                    meta={
+                        'timestamp': f'2022-08-{i+1}T10:00:00.0000+00:00',
+                        'accessLevel': 'standard',
+                        'repo': 'sample-metadata',
+                        'commit': '7234c13855cc15b3471d340757ce87e7441abeb9',
+                        'script': 'python3 -m <script>',
+                        'description': f'Run {i+1}',
+                        'driverImage': 'australia-southeast1-docker.pkg.dev/ar/images/image:tag',
+                        'configPath': 'gs://cpg-config/<guid>.toml',
+                        'cwd': None,
+                        'hailVersion': '0.2.126-cac7ac4164b2',
+                        'batch_url': 'https://batch.hail.populationgenomics.org.au/batches/0',
+                        'source': 'analysis-runner',
+                    },
+                ),
+            )
+            for i in range(ar_entries)
+        )
+    )
+

 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
1 change: 0 additions & 1 deletion web/src/pages/family/FamilyView.tsx
@@ -28,7 +28,6 @@ query FamilyInfo($family_id: Int!) {
     id
     samples {
       active
-      author
       externalId
       id
       meta