diff --git a/api/graphql/loaders.py b/api/graphql/loaders.py index 81fe6b44e..de3606668 100644 --- a/api/graphql/loaders.py +++ b/api/graphql/loaders.py @@ -10,10 +10,10 @@ from strawberry.dataloader import DataLoader from api.utils import get_projectless_db_connection, group_by -from db.python.connect import NotFoundError from db.python.layers import ( AnalysisLayer, AssayLayer, + AuditLogLayer, FamilyLayer, ParticipantLayer, SampleLayer, @@ -24,16 +24,18 @@ from db.python.tables.project import ProjectPermissionsTable from db.python.tables.sample import SampleFilter from db.python.tables.sequencing_group import SequencingGroupFilter -from db.python.utils import GenericFilter, ProjectId +from db.python.utils import GenericFilter, NotFoundError from models.models import ( AnalysisInternal, AssayInternal, FamilyInternal, ParticipantInternal, Project, + ProjectId, SampleInternal, SequencingGroupInternal, ) +from models.models.audit_log import AuditLogInternal class LoaderKeys(enum.Enum): @@ -44,6 +46,8 @@ class LoaderKeys(enum.Enum): PROJECTS_FOR_IDS = 'projects_for_id' + AUDIT_LOGS_BY_IDS = 'audit_logs_by_ids' + ANALYSES_FOR_SEQUENCING_GROUPS = 'analyses_for_sequencing_groups' ASSAYS_FOR_SAMPLES = 'sequences_for_samples' @@ -168,6 +172,19 @@ async def wrapped(query: list[dict[str, Any]]) -> list[Any]: return connected_data_loader_caller +@connected_data_loader(LoaderKeys.AUDIT_LOGS_BY_IDS) +async def load_audit_logs_by_ids( + audit_log_ids: list[int], connection +) -> list[AuditLogInternal | None]: + """ + DataLoader: get_audit_logs_by_ids + """ + alayer = AuditLogLayer(connection) + logs = await alayer.get_for_ids(audit_log_ids) + logs_by_id = {log.id: log for log in logs} + return [logs_by_id.get(a) for a in audit_log_ids] + + @connected_data_loader_with_params(LoaderKeys.ASSAYS_FOR_SAMPLES, default_factory=list) async def load_assays_by_samples( connection, ids, filter: AssayFilter @@ -332,7 +349,7 @@ async def load_projects_for_ids(project_ids: list[int], connection) -> list[Proj """ Get projects by IDs """ - pttable = ProjectPermissionsTable(connection.connection) + pttable = ProjectPermissionsTable(connection) projects = await pttable.get_and_check_access_to_projects_for_ids( user=connection.user, project_ids=project_ids, readonly=True ) diff --git a/api/graphql/schema.py b/api/graphql/schema.py index f8d70dcbd..730c2d61a 100644 --- a/api/graphql/schema.py +++ b/api/graphql/schema.py @@ -31,6 +31,7 @@ from models.models import ( AnalysisInternal, AssayInternal, + AuditLogInternal, FamilyInternal, ParticipantInternal, Project, @@ -204,6 +205,27 @@ async def analyses( return [GraphQLAnalysis.from_internal(a) for a in internal_analysis] +@strawberry.type +class GraphQLAuditLog: + """AuditLog GraphQL model""" + + id: int + author: str + timestamp: datetime.datetime + ar_guid: str | None + comment: str | None + + @staticmethod + def from_internal(audit_log: AuditLogInternal) -> 'GraphQLAuditLog': + return GraphQLAuditLog( + id=audit_log.id, + author=audit_log.author, + timestamp=audit_log.timestamp, + ar_guid=audit_log.ar_guid, + comment=audit_log.comment, + ) + + @strawberry.type class GraphQLAnalysis: """Analysis GraphQL model""" @@ -297,6 +319,7 @@ class GraphQLParticipant: karyotype: str | None project_id: strawberry.Private[int] + audit_log_id: strawberry.Private[int | None] @staticmethod def from_internal(internal: ParticipantInternal) -> 'GraphQLParticipant': @@ -308,6 +331,7 @@ def from_internal(internal: ParticipantInternal) -> 'GraphQLParticipant': reported_gender=internal.reported_gender, karyotype=internal.karyotype, project_id=internal.project, + audit_log_id=internal.audit_log_id, ) @strawberry.field @@ -349,6 +373,16 @@ async def project(self, info: Info, root: 'GraphQLParticipant') -> GraphQLProjec project = await loader.load(root.project_id) return GraphQLProject.from_internal(project) + @strawberry.field + async def audit_log( + self, info: Info, root: 'GraphQLParticipant' + ) -> GraphQLAuditLog | None: + if root.audit_log_id is None: + return None + loader = info.context[LoaderKeys.AUDIT_LOGS_BY_IDS] + audit_log = await loader.load(root.audit_log_id) + return GraphQLAuditLog.from_internal(audit_log) + @strawberry.type class GraphQLSample: @@ -367,7 +401,7 @@ class GraphQLSample: project_id: strawberry.Private[int] @staticmethod - def from_internal(sample: SampleInternal): + def from_internal(sample: SampleInternal) -> 'GraphQLSample': return GraphQLSample( id=sample_id_format(sample.id), external_id=sample.external_id, @@ -491,7 +525,7 @@ async def analyses( loader = info.context[LoaderKeys.ANALYSES_FOR_SEQUENCING_GROUPS] project_id_map = {} if project: - ptable = ProjectPermissionsTable(connection.connection) + ptable = ProjectPermissionsTable(connection) project_ids = project.all_values() projects = await ptable.get_and_check_access_to_projects_for_names( user=connection.author, project_names=project_ids, readonly=True @@ -564,7 +598,7 @@ def enum(self, info: Info) -> GraphQLEnum: @strawberry.field() async def project(self, info: Info, name: str) -> GraphQLProject: connection = info.context['connection'] - ptable = ProjectPermissionsTable(connection.connection) + ptable = ProjectPermissionsTable(connection) project = await ptable.get_and_check_access_to_project_for_name( user=connection.author, project_name=name, readonly=True ) @@ -583,7 +617,7 @@ async def sample( active: GraphQLFilter[bool] | None = None, ) -> list[GraphQLSample]: connection = info.context['connection'] - ptable = ProjectPermissionsTable(connection.connection) + ptable = ProjectPermissionsTable(connection) slayer = SampleLayer(connection) if not id and not project: @@ -631,7 +665,7 @@ async def sequencing_groups( ) -> list[GraphQLSequencingGroup]: connection = info.context['connection'] sglayer = SequencingGroupLayer(connection) - ptable = ProjectPermissionsTable(connection.connection) + ptable = ProjectPermissionsTable(connection) if not (project or sample_id or id): raise ValueError('Must filter by project, sample or id') @@ -685,7 +719,7 @@ async def family(self, info: Info, family_id: int) -> GraphQLFamily: @strawberry.field async def my_projects(self, info: Info) -> list[GraphQLProject]: connection = info.context['connection'] - ptable = ProjectPermissionsTable(connection.connection) + ptable = ProjectPermissionsTable(connection) projects = await ptable.get_projects_accessible_by_user( connection.author, readonly=True ) diff --git a/api/routes/analysis.py b/api/routes/analysis.py index a4c8c6f95..6aaa86645 100644 --- a/api/routes/analysis.py +++ b/api/routes/analysis.py @@ -104,10 +104,12 @@ async def create_analysis( atable = AnalysisLayer(connection) + if analysis.author: + # special tracking here, if we can't catch it through the header + connection.on_behalf_of = analysis.author + analysis_id = await atable.create_analysis( analysis.to_internal(), - # analysis-runner: usage is tracked through `on_behalf_of` - author=analysis.author, ) return analysis_id @@ -226,7 +228,7 @@ async def query_analyses( if not query.projects: raise ValueError('Must specify "projects"') - pt = ProjectPermissionsTable(connection=connection.connection) + pt = ProjectPermissionsTable(connection) projects = await pt.get_and_check_access_to_projects_for_names( user=connection.author, project_names=query.projects, readonly=True ) @@ -249,7 +251,7 @@ async def get_analysis_runner_log( atable = AnalysisLayer(connection) project_ids = None if project_names: - pt = ProjectPermissionsTable(connection=connection.connection) + pt = ProjectPermissionsTable(connection) project_ids = await pt.get_project_ids_from_names_and_user( connection.author, project_names, readonly=True ) @@ -335,7 +337,7 @@ async def get_proportionate_map( } } """ - pt = ProjectPermissionsTable(connection=connection.connection) + pt = ProjectPermissionsTable(connection) project_ids = await pt.get_project_ids_from_names_and_user( connection.author, projects, readonly=True ) diff --git a/api/routes/assay.py b/api/routes/assay.py index d0b386ccb..3770d5a76 100644 --- a/api/routes/assay.py +++ b/api/routes/assay.py @@ -79,7 +79,7 @@ async def get_assays_by_criteria( ): """Get assays by criteria""" assay_layer = AssayLayer(connection) - pt = ProjectPermissionsTable(connection.connection) + pt = ProjectPermissionsTable(connection) pids: list[int] | None = None if projects: diff --git a/api/routes/project.py b/api/routes/project.py index 5771dfc91..559938485 100644 --- a/api/routes/project.py +++ b/api/routes/project.py @@ -12,14 +12,14 @@ @router.get('/all', operation_id='getAllProjects', response_model=List[Project]) async def get_all_projects(connection=get_projectless_db_connection): """Get list of projects""" - ptable = ProjectPermissionsTable(connection.connection) + ptable = ProjectPermissionsTable(connection) return await ptable.get_all_projects(author=connection.author) @router.get('/', operation_id='getMyProjects', response_model=List[str]) async def get_my_projects(connection=get_projectless_db_connection): """Get projects I have access to""" - ptable = ProjectPermissionsTable(connection.connection) + ptable = ProjectPermissionsTable(connection) projects = await ptable.get_projects_accessible_by_user( author=connection.author, readonly=True ) @@ -36,7 +36,7 @@ async def create_project( """ Create a new project """ - ptable = ProjectPermissionsTable(connection.connection) + ptable = ProjectPermissionsTable(connection) pid = await ptable.create_project( project_name=name, dataset_name=dataset, @@ -56,7 +56,7 @@ async def create_project( @router.get('/seqr/all', operation_id='getSeqrProjects') async def get_seqr_projects(connection: Connection = get_projectless_db_connection): """Get SM projects that should sync to seqr""" - ptable = ProjectPermissionsTable(connection.connection) + ptable = ProjectPermissionsTable(connection) return await ptable.get_seqr_projects() @@ -67,7 +67,7 @@ async def update_project( connection: Connection = get_projectless_db_connection, ): """Update a project by project name""" - ptable = ProjectPermissionsTable(connection.connection) + ptable = ProjectPermissionsTable(connection) return await ptable.update_project( project_name=project, update=project_update_model, author=connection.author ) @@ -84,7 +84,7 @@ async def delete_project_data( Can optionally delete the project itself. Requires READ access + project-creator permissions """ - ptable = ProjectPermissionsTable(connection.connection) + ptable = ProjectPermissionsTable(connection) p_obj = await ptable.get_and_check_access_to_project_for_name( user=connection.author, project_name=project, readonly=False ) @@ -106,7 +106,7 @@ async def update_project_members( Update project members for specific read / write group. Not that this is protected by access to a specific access group """ - ptable = ProjectPermissionsTable(connection.connection) + ptable = ProjectPermissionsTable(connection) await ptable.set_group_members( group_name=ptable.get_project_group_name(project, readonly=readonly), members=members, diff --git a/api/routes/sample.py b/api/routes/sample.py index 1b4bf76d9..646a2dc7b 100644 --- a/api/routes/sample.py +++ b/api/routes/sample.py @@ -130,7 +130,7 @@ async def get_samples( """ st = SampleLayer(connection) - pt = ProjectPermissionsTable(connection.connection) + pt = ProjectPermissionsTable(connection) pids: list[int] | None = None if project_ids: pids = await pt.get_project_ids_from_names_and_user( diff --git a/api/routes/web.py b/api/routes/web.py index 38255514b..789e174d0 100644 --- a/api/routes/web.py +++ b/api/routes/web.py @@ -77,7 +77,7 @@ async def search_by_keyword(keyword: str, connection=get_projectless_db_connecti that you are a part of (automatically). """ # raise ValueError("Test") - pt = ProjectPermissionsTable(connection.connection) + pt = ProjectPermissionsTable(connection) projects = await pt.get_projects_accessible_by_user( connection.author, readonly=True ) diff --git a/api/utils/db.py b/api/utils/db.py index 7a22944fc..8e23419e8 100644 --- a/api/utils/db.py +++ b/api/utils/db.py @@ -1,6 +1,5 @@ from os import getenv import logging -from typing import Optional from fastapi import Depends, HTTPException, Request from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials @@ -11,12 +10,12 @@ from api.utils.gcp import email_from_id_token from db.python.connect import SMConnections, Connection from db.python.gcp_connect import BqConnection, PubSubConnection - +from db.python.tables.project import ProjectPermissionsTable EXPECTED_AUDIENCE = getenv('SM_OAUTHAUDIENCE') -def get_jwt_from_request(request: Request) -> Optional[str]: +def get_jwt_from_request(request: Request) -> str | None: """ Get google JWT value, capture it like this instead of using x_goog_iap_jwt_assertion = Header(None) @@ -25,11 +24,22 @@ def get_jwt_from_request(request: Request) -> Optional[str]: return request.headers.get('x-goog-iap-jwt-assertion') +def get_ar_guid(request: Request) -> str | None: + """Get sm-ar-guid from the headers to provide with requests""" + return request.headers.get('sm-ar-guid') + + +def get_on_behalf_of(request: Request) -> str | None: + """ + Get sm-on-behalf-of if there are requests that were performed on behalf of + someone else (some automated process) + """ + return request.headers.get('sm-on-behalf-of') + + def authenticate( - token: Optional[HTTPAuthorizationCredentials] = Depends( - HTTPBearer(auto_error=False) - ), - x_goog_iap_jwt_assertion: Optional[str] = Depends(get_jwt_from_request), + token: HTTPAuthorizationCredentials | None = Depends(HTTPBearer(auto_error=False)), + x_goog_iap_jwt_assertion: str | None = Depends(get_jwt_from_request), ) -> str: """ If a token (OR Google IAP auth jwt) is provided, @@ -55,26 +65,41 @@ def authenticate( async def dependable_get_write_project_connection( - project: str, author: str = Depends(authenticate) + project: str, + author: str = Depends(authenticate), + ar_guid: str = Depends(get_ar_guid), + on_behalf_of: str | None = Depends(get_on_behalf_of), ) -> Connection: """FastAPI handler for getting connection WITH project""" - return await SMConnections.get_connection( - project_name=project, author=author, readonly=False + return await ProjectPermissionsTable.get_project_connection( + project_name=project, + author=author, + readonly=False, + ar_guid=ar_guid, + on_behalf_of=on_behalf_of, ) async def dependable_get_readonly_project_connection( - project: str, author: str = Depends(authenticate) + project: str, + author: str = Depends(authenticate), + ar_guid: str = Depends(get_ar_guid), ) -> Connection: """FastAPI handler for getting connection WITH project""" - return await SMConnections.get_connection( - project_name=project, author=author, readonly=True + return await ProjectPermissionsTable.get_project_connection( + project_name=project, + author=author, + readonly=True, + on_behalf_of=None, + ar_guid=ar_guid, ) -async def dependable_get_connection(author: str = Depends(authenticate)): +async def dependable_get_connection( + author: str = Depends(authenticate), ar_guid: str = Depends(get_ar_guid) +): """FastAPI handler for getting connection withOUT project""" - return await SMConnections.get_connection_no_project(author) + return await SMConnections.get_connection_no_project(author, ar_guid=ar_guid) async def dependable_get_bq_connection(author: str = Depends(authenticate)): diff --git a/api/utils/exceptions.py b/api/utils/exceptions.py index c418ef6cd..ff24b6f6e 100644 --- a/api/utils/exceptions.py +++ b/api/utils/exceptions.py @@ -1,5 +1,4 @@ -from db.python.connect import NotFoundError -from db.python.utils import Forbidden +from db.python.utils import Forbidden, NotFoundError def determine_code_from_error(e): diff --git a/db/project.xml b/db/project.xml index 54099a8de..118a934ba 100644 --- a/db/project.xml +++ b/db/project.xml @@ -843,4 +843,268 @@ INSERT INTO `group` (name) VALUES ('project-creators'); INSERT INTO `group` (name) VALUES ('members-admin'); + + SET @@system_versioning_alter_history = 1; + + + + + + + + + + + + + + + + + + + + + + + + + ALTER TABLE `audit_log` ADD SYSTEM VERSIONING; + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + ALTER TABLE analysis CHANGE author author VARCHAR(255) NULL; + ALTER TABLE assay CHANGE author author VARCHAR(255) NULL; + ALTER TABLE assay_external_id CHANGE author author VARCHAR(255) NULL; + ALTER TABLE cohort CHANGE author author VARCHAR(255) NULL; + ALTER TABLE family CHANGE author author VARCHAR(255) NULL; + ALTER TABLE family_participant CHANGE author author VARCHAR(255) NULL; + ALTER TABLE group_member CHANGE author author VARCHAR(255) NULL DEFAULT NULL; + ALTER TABLE participant CHANGE author author VARCHAR(255) NULL; + ALTER TABLE participant_phenotypes CHANGE author author VARCHAR(255) NULL; + ALTER TABLE project CHANGE author author VARCHAR(255) NULL; + ALTER TABLE sample CHANGE author author VARCHAR(255) NULL; + ALTER TABLE sample_sequencing CHANGE author author VARCHAR(255) NULL; + ALTER TABLE sample_sequencing_eid CHANGE author author VARCHAR(255) NULL; + ALTER TABLE sequencing_group CHANGE author author VARCHAR(255) NULL; + ALTER TABLE sequencing_group_assay CHANGE author author VARCHAR(255) NULL; + ALTER TABLE sequencing_group_external_id CHANGE author author VARCHAR(255) NULL; + diff --git a/db/python/connect.py b/db/python/connect.py index c72b69fcf..a800ac4c9 100644 --- a/db/python/connect.py +++ b/db/python/connect.py @@ -4,20 +4,22 @@ Code for connecting to Postgres database """ import abc +import asyncio import json import logging import os -from typing import Optional +from threading import Lock import databases from api.settings import LOG_DATABASE_QUERIES -from db.python.tables.project import ProjectPermissionsTable -from db.python.utils import InternalError, NoOpAenter, NotFoundError +from db.python.utils import InternalError logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) +audit_log_lock = Lock() + TABLES_ORDERED_BY_FK_DEPS = [ 'project', 'group', @@ -44,12 +46,44 @@ class Connection: def __init__( self, connection: databases.Database, - project: Optional[int], + project: int | None, author: str, + on_behalf_of: str | None, + readonly: bool, + ar_guid: str | None, ): self.connection: databases.Database = connection - self.project: Optional[int] = project + self.project: int | None = project self.author: str = author + self.on_behalf_of: str | None = on_behalf_of + self.readonly: bool = readonly + self.ar_guid: str | None = ar_guid + + self._audit_log_id: int | None = None + + async def audit_log_id(self): + """Get audit_log ID for write operations, cached per connection""" + if self.readonly: + raise InternalError( + 'Trying to get a audit_log ID, but not a write connection' + ) + + with audit_log_lock: + if not self._audit_log_id: + # pylint: disable=import-outside-toplevel + # make this import here, otherwise we'd have a circular import + from db.python.tables.audit_log import AuditLogTable + + at = AuditLogTable(self) + self._audit_log_id = await at.create_audit_log( + author=self.author, + on_behalf_of=self.on_behalf_of, + ar_guid=self.ar_guid, + comment=None, + project=self.project, + ) + + return self._audit_log_id def assert_requires_project(self): """Assert the project is set, or return an exception""" @@ -132,11 +166,7 @@ def get_connection_string(self): class SMConnections: """Contains useful functions for connecting to the database""" - # _connected = False - # _connections: Dict[str, databases.Database] = {} - # _admin_db: databases.Database = None - - _credentials: Optional[DatabaseConfiguration] = None + _credentials: DatabaseConfiguration | None = None @staticmethod def _get_config(): @@ -175,7 +205,10 @@ async def disconnect(): return False @staticmethod - async def _get_made_connection(): + async def get_made_connection(): + """ + Makes a new connection to the database on each call + """ credentials = SMConnections._get_config() if credentials is None: @@ -190,70 +223,21 @@ async def _get_made_connection(): return conn @staticmethod - async def get_connection(*, author: str, project_name: str, readonly: bool): - """Get a db connection from a project and user""" - # maybe it makes sense to perform permission checks here too - logger.debug(f'Authenticate connection to {project_name} with {author!r}') - - conn = await SMConnections._get_made_connection() - pt = ProjectPermissionsTable(connection=conn) - - project = await pt.get_and_check_access_to_project_for_name( - user=author, project_name=project_name, readonly=readonly - ) - - return Connection(connection=conn, author=author, project=project.id) - - @staticmethod - async def get_connection_no_project(author: str): + async def get_connection_no_project(author: str, ar_guid: str): """Get a db connection from a project and user""" # maybe it makes sense to perform permission checks here too logger.debug(f'Authenticate no-project connection with {author!r}') - conn = await SMConnections._get_made_connection() + conn = await SMConnections.get_made_connection() # we don't authenticate project-less connection, but rely on the # the endpoint to validate the resources - return Connection(connection=conn, author=author, project=None) - - -class DbBase: - """Base class for table subclasses""" - - @classmethod - async def from_project(cls, project, author, readonly: bool): - """Create the Db object from a project with user details""" - return cls( - connection=await SMConnections.get_connection( - project_name=project, author=author, readonly=readonly - ), + return Connection( + connection=conn, + author=author, + project=None, + on_behalf_of=None, + ar_guid=ar_guid, + readonly=False, ) - - def __init__(self, connection: Connection): - if connection is None: - raise InternalError( - f'No connection was provided to the table {self.__class__.__name__!r}' - ) - if not isinstance(connection, Connection): - raise InternalError( - f'Expected connection type Connection, received {type(connection)}, ' - f'did you mean to call self._connection?' - ) - - self._connection = connection - self.connection: databases.Database = connection.connection - self.author = connection.author - self.project = connection.project - - if self.author is None: - raise InternalError(f'Must provide author to {self.__class__.__name__}') - - # piped from the connection - - @staticmethod - def escape_like_term(query: str): - """ - Escape meaningful keys when using LIKE with a user supplied input - """ - return query.replace('%', '\\%').replace('_', '\\_') diff --git a/db/python/enum_tables/enums.py b/db/python/enum_tables/enums.py index 113daf3d6..1313b0cb4 100644 --- a/db/python/enum_tables/enums.py +++ b/db/python/enum_tables/enums.py @@ -4,7 +4,7 @@ from async_lru import alru_cache -from db.python.connect import DbBase +from db.python.tables.base import DbBase table_name_matcher = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$') @@ -58,12 +58,14 @@ async def insert(self, value: str): Insert a new type """ _query = f""" - INSERT INTO {self._get_table_name()} (id, name) - VALUES (:name, :name) - ON DUPLICATE KEY UPDATE name = :name + INSERT INTO {self._get_table_name()} (id, name, audit_log_id) + VALUES (:name, :name, :audit_log_id) + ON DUPLICATE KEY UPDATE name = :name, audit_log_id = :audit_log_id """ - await self.connection.execute(_query, {'name': value.lower()}) + await self.connection.execute( + _query, {'name': value.lower(), 'audit_log_id': await self.audit_log_id()} + ) # clear the cache so results are up-to-date self.get.cache_clear() # pylint: disable=no-member return value diff --git a/db/python/layers/__init__.py b/db/python/layers/__init__.py index f67ab78c6..dafa085df 100644 --- a/db/python/layers/__init__.py +++ b/db/python/layers/__init__.py @@ -1,5 +1,6 @@ from db.python.layers.analysis import AnalysisLayer from db.python.layers.assay import AssayLayer +from db.python.layers.audit_log import AuditLogLayer from db.python.layers.base import BaseLayer from db.python.layers.family import FamilyLayer from db.python.layers.participant import ParticipantLayer diff --git a/db/python/layers/analysis.py b/db/python/layers/analysis.py index e3bc1dd1d..67fd18c13 100644 --- a/db/python/layers/analysis.py +++ b/db/python/layers/analysis.py @@ -1,5 +1,5 @@ -from collections import defaultdict import datetime +from collections import defaultdict from typing import Any from api.utils import group_by @@ -7,7 +7,6 @@ from db.python.layers.base import BaseLayer from db.python.layers.sequencing_group import SequencingGroupLayer from db.python.tables.analysis import AnalysisFilter, AnalysisTable -from db.python.tables.project import ProjectId from db.python.tables.sample import SampleTable from db.python.tables.sequencing_group import SequencingGroupFilter from db.python.utils import GenericFilter, get_logger @@ -19,6 +18,7 @@ ProportionalDateTemporalMethod, SequencingGroupInternal, ) +from models.models.project import ProjectId from models.models.sequencing_group import SequencingGroupInternalId ES_ANALYSIS_OBJ_INTRO_DATE = datetime.date(2022, 6, 21) @@ -535,7 +535,6 @@ async def get_sgs_added_by_day_by_es_indices( async def create_analysis( self, analysis: AnalysisInternal, - author: str = None, project: ProjectId = None, ) -> int: """Create a new analysis""" @@ -546,7 +545,6 @@ async def create_analysis( meta=analysis.meta, output=analysis.output, active=analysis.active, - author=author, project=project, ) @@ -570,7 +568,6 @@ async def update_analysis( status: AnalysisStatus, meta: dict[str, Any] = None, output: str | None = None, - author: str | None = None, check_project_id=True, ): """ @@ -587,7 +584,6 @@ async def update_analysis( status=status, meta=meta, output=output, - author=author, ) async def get_analysis_runner_log( diff --git a/db/python/layers/assay.py b/db/python/layers/assay.py index 32dc9aabb..0e512e6e7 100644 --- a/db/python/layers/assay.py +++ b/db/python/layers/assay.py @@ -1,7 +1,7 @@ # pylint: disable=too-many-arguments from typing import Any -from db.python.connect import NoOpAenter +from db.python.utils import NoOpAenter from db.python.layers.base import BaseLayer, Connection from db.python.tables.assay import AssayTable, AssayFilter from db.python.tables.sample import SampleTable diff --git a/db/python/layers/audit_log.py b/db/python/layers/audit_log.py new file mode 100644 index 000000000..1cc351ecc --- /dev/null +++ b/db/python/layers/audit_log.py @@ -0,0 +1,30 @@ +from db.python.layers.base import BaseLayer, Connection +from db.python.tables.audit_log import AuditLogTable +from models.models.audit_log import AuditLogId, AuditLogInternal + + +class AuditLogLayer(BaseLayer): + """Layer for more complex sample logic""" + + def __init__(self, connection: Connection): + super().__init__(connection) + self.alayer: AuditLogTable = AuditLogTable(connection) + + # GET + async def get_for_ids( + self, ids: list[AuditLogId], check_project_id=True + ) -> list[AuditLogInternal]: + """Query for samples""" + if not ids: + return [] + + logs = await self.alayer.get_audit_logs_for_ids(ids) + if check_project_id: + projects = {log.auth_project for log in logs} + await self.ptable.check_access_to_project_ids( + user=self.author, project_ids=projects, readonly=True + ) + + return logs + + # don't put the create here, I don't want it to be public diff --git a/db/python/layers/base.py b/db/python/layers/base.py index 3b8f282eb..28c7575a7 100644 --- a/db/python/layers/base.py +++ b/db/python/layers/base.py @@ -1,4 +1,4 @@ -from db.python.connect import Connection, ProjectPermissionsTable +from db.python.tables.project import Connection, ProjectPermissionsTable class BaseLayer: @@ -6,7 +6,7 @@ class BaseLayer: def __init__(self, connection: Connection): self.connection = connection - self.ptable = ProjectPermissionsTable(self.connection.connection) + self.ptable = ProjectPermissionsTable(connection) @property def author(self): diff --git a/db/python/layers/family.py b/db/python/layers/family.py index d90aa2612..3d77576b0 100644 --- a/db/python/layers/family.py +++ b/db/python/layers/family.py @@ -1,6 +1,6 @@ # pylint: disable=used-before-assignment import logging -from typing import List, Union, Optional, Dict +from typing import Dict, List, Optional, Union from db.python.connect import Connection from db.python.layers.base import BaseLayer @@ -8,11 +8,11 @@ from db.python.tables.family import FamilyTable from db.python.tables.family_participant import FamilyParticipantTable from db.python.tables.participant import ParticipantTable -from db.python.tables.project import ProjectId -from db.python.tables.sample import SampleTable, SampleFilter +from db.python.tables.sample import SampleFilter, SampleTable from db.python.utils import GenericFilter -from models.models.family import PedRowInternal, FamilyInternal +from models.models.family import FamilyInternal, PedRowInternal from models.models.participant import ParticipantUpsertInternal +from models.models.project import ProjectId class PedRow: diff --git a/db/python/layers/participant.py b/db/python/layers/participant.py index 44d6d4db2..3698e0cc8 100644 --- a/db/python/layers/participant.py +++ b/db/python/layers/participant.py @@ -4,7 +4,6 @@ from enum import Enum from typing import Any, Dict, List, Optional, Tuple -from db.python.connect import NoOpAenter, NotFoundError from db.python.layers.base import BaseLayer from db.python.layers.sample import SampleLayer from db.python.tables.family import FamilyTable @@ -12,9 +11,10 @@ from db.python.tables.participant import ParticipantTable from db.python.tables.participant_phenotype import ParticipantPhenotypeTable from db.python.tables.sample import SampleTable -from db.python.utils import ProjectId, split_generic_terms +from db.python.utils import NoOpAenter, NotFoundError, split_generic_terms from models.models.family import PedRowInternal from models.models.participant import ParticipantInternal, ParticipantUpsertInternal +from models.models.project import ProjectId HPO_REGEX_MATCHER = re.compile(r'HP\:\d+$') @@ -571,7 +571,6 @@ async def get_external_participant_id_to_internal_sequencing_group_id_map( async def upsert_participant( self, participant: ParticipantUpsertInternal, - author: str = None, project: ProjectId = None, check_project_id: bool = True, open_transaction=True, @@ -602,7 +601,6 @@ async def upsert_participant( reported_gender=participant.reported_gender, meta=participant.meta, karyotype=participant.karyotype, - author=author, ) else: @@ -612,7 +610,6 @@ async def upsert_participant( reported_gender=participant.reported_gender, karyotype=participant.karyotype, meta=participant.meta, - author=author, project=project, ) @@ -623,7 +620,6 @@ async def upsert_participant( await slayer.upsert_samples( participant.samples, - author=author, project=project, check_project_id=False, open_transaction=False, @@ -802,7 +798,6 @@ async def add_participant_to_family( maternal_id=maternal_id, affected=affected, notes=None, - author=None, ) @staticmethod diff --git a/db/python/layers/sample.py b/db/python/layers/sample.py index 4b4d2f4f1..dbb3dbcf0 100644 --- a/db/python/layers/sample.py +++ b/db/python/layers/sample.py @@ -2,18 +2,16 @@ from typing import Any from api.utils import group_by -from db.python.connect import NotFoundError from db.python.layers.assay import AssayLayer from db.python.layers.base import BaseLayer, Connection from db.python.layers.sequencing_group import SequencingGroupLayer from db.python.tables.assay import NoOpAenter -from db.python.tables.project import ProjectId, ProjectPermissionsTable -from db.python.tables.sample import SampleTable, SampleFilter -from db.python.utils import GenericFilter +from db.python.tables.project import ProjectPermissionsTable +from db.python.tables.sample import SampleFilter, SampleTable +from db.python.utils import GenericFilter, NotFoundError +from models.models.project import ProjectId from models.models.sample import SampleInternal, SampleUpsertInternal -from models.utils.sample_id_format import ( - sample_id_format_list, -) +from models.utils.sample_id_format import sample_id_format_list class SampleLayer(BaseLayer): @@ -22,7 +20,7 @@ class SampleLayer(BaseLayer): def __init__(self, connection: Connection): super().__init__(connection) self.st: SampleTable = SampleTable(connection) - self.pt = ProjectPermissionsTable(connection.connection) + self.pt = ProjectPermissionsTable(connection) self.connection = connection # GETS @@ -220,7 +218,6 @@ async def get_samples_create_date( async def upsert_sample( self, sample: SampleUpsertInternal, - author: str = None, project: ProjectId = None, process_sequencing_groups: bool = True, process_assays: bool = True, @@ -239,7 +236,6 @@ async def upsert_sample( active=True, meta=sample.meta, participant_id=sample.participant_id, - author=author, project=project, ) else: @@ -276,7 +272,6 @@ async def upsert_samples( self, samples: list[SampleUpsertInternal], open_transaction: bool = True, - author: str = None, project: ProjectId = None, check_project_id=True, ) -> list[SampleUpsertInternal]: @@ -300,7 +295,6 @@ async def upsert_samples( for sample in samples: await self.upsert_sample( sample, - author=author, project=project, process_sequencing_groups=False, process_assays=False, @@ -329,20 +323,18 @@ async def merge_samples( self, id_keep: int, id_merge: int, - author=None, check_project_id=True, ): """Merge two samples into one another""" if check_project_id: projects = await self.st.get_project_ids_for_sample_ids([id_keep, id_merge]) await self.ptable.check_access_to_project_ids( - user=author or self.author, project_ids=projects, readonly=False + user=self.author, project_ids=projects, readonly=False ) return await self.st.merge_samples( id_keep=id_keep, id_merge=id_merge, - author=author, ) async def update_many_participant_ids( diff --git a/db/python/layers/search.py b/db/python/layers/search.py index eb0ea2742..3e94a7ca9 100644 --- a/db/python/layers/search.py +++ b/db/python/layers/search.py @@ -1,7 +1,7 @@ import asyncio from typing import List, Optional -from db.python.connect import NotFoundError +from db.python.utils import NotFoundError from db.python.layers.base import BaseLayer, Connection from db.python.tables.family import FamilyTable from db.python.tables.participant import ParticipantTable @@ -28,7 +28,7 @@ class SearchLayer(BaseLayer): def __init__(self, connection: Connection): super().__init__(connection) - self.pt = ProjectPermissionsTable(connection.connection) + self.pt = ProjectPermissionsTable(connection) self.connection = connection @staticmethod diff --git a/db/python/layers/seqr.py b/db/python/layers/seqr.py index 2bf1de14d..cad435852 100644 --- a/db/python/layers/seqr.py +++ b/db/python/layers/seqr.py @@ -29,7 +29,7 @@ from db.python.layers.participant import ParticipantLayer from db.python.layers.sequencing_group import SequencingGroupLayer from db.python.tables.analysis import AnalysisFilter -from db.python.tables.project import Project, ProjectPermissionsTable +from db.python.tables.project import Project from db.python.utils import GenericFilter from models.enums import AnalysisStatus @@ -122,8 +122,7 @@ async def sync_dataset( raise ValueError('Seqr synchronisation is not configured in metamist') token = self.generate_seqr_auth_token() - pptable = ProjectPermissionsTable(connection=self.connection.connection) - project = await pptable.get_and_check_access_to_project_for_id( + project = await self.ptable.get_and_check_access_to_project_for_id( self.connection.author, project_id=self.connection.project, readonly=True, diff --git a/db/python/layers/sequencing_group.py b/db/python/layers/sequencing_group.py index 2e7d073f5..e68150c76 100644 --- a/db/python/layers/sequencing_group.py +++ b/db/python/layers/sequencing_group.py @@ -1,6 +1,6 @@ from datetime import date -from db.python.connect import Connection, NotFoundError +from db.python.connect import Connection from db.python.layers.assay import AssayLayer from db.python.layers.base import BaseLayer from db.python.tables.assay import AssayTable, NoOpAenter @@ -9,11 +9,12 @@ SequencingGroupFilter, SequencingGroupTable, ) -from db.python.utils import ProjectId +from db.python.utils import NotFoundError +from models.models.project import ProjectId from models.models.sequencing_group import ( SequencingGroupInternal, - SequencingGroupUpsertInternal, SequencingGroupInternalId, + SequencingGroupUpsertInternal, ) from models.utils.sequencing_group_id_format import sequencing_group_id_format @@ -261,7 +262,6 @@ async def recreate_sequencing_group_with_new_assays( platform=seqgroup.platform, meta={**seqgroup.meta, **meta}, assay_ids=assays, - author=self.author, open_transaction=False, ) diff --git a/db/python/layers/web.py b/db/python/layers/web.py index ee580297d..20e6c82fd 100644 --- a/db/python/layers/web.py +++ b/db/python/layers/web.py @@ -7,12 +7,12 @@ from datetime import date from api.utils import group_by -from db.python.connect import DbBase from db.python.layers.base import BaseLayer from db.python.layers.sample import SampleLayer from db.python.layers.seqr import SeqrLayer from db.python.tables.analysis import AnalysisTable from db.python.tables.assay import AssayTable +from db.python.tables.base import DbBase from db.python.tables.project import ProjectPermissionsTable from db.python.tables.sequencing_group import SequencingGroupTable from models.models import ( @@ -22,6 +22,7 @@ NestedSampleInternal, NestedSequencingGroupInternal, SearchItem, + parse_sql_bool, ) from models.models.web import ProjectSummaryInternal, WebProject @@ -189,7 +190,7 @@ def _project_summary_process_sample_rows( created_date=str(sample_id_start_times.get(s['id'], '')), sequencing_groups=sg_models_by_sample_id.get(s['id'], []), non_sequencing_assays=filtered_assay_models_by_sid.get(s['id'], []), - active=bool(ord(s['active'])), + active=parse_sql_bool(s['active']), ) for s in sample_rows ] @@ -278,7 +279,7 @@ async def get_project_summary( # do initial query to get sample info sampl = SampleLayer(self._connection) sample_query, values = self._project_summary_sample_query(grid_filter) - ptable = ProjectPermissionsTable(self.connection) + ptable = ProjectPermissionsTable(self._connection) project_db = await ptable.get_and_check_access_to_project_for_id( self.author, self.project, readonly=True ) diff --git a/db/python/tables/analysis.py b/db/python/tables/analysis.py index b079a0e42..d5c5d5fb7 100644 --- a/db/python/tables/analysis.py +++ b/db/python/tables/analysis.py @@ -4,16 +4,17 @@ from collections import defaultdict from typing import Any, Dict, List, Optional, Set, Tuple -from db.python.connect import DbBase, NotFoundError -from db.python.tables.project import ProjectId +from db.python.tables.base import DbBase from db.python.utils import ( GenericFilter, GenericFilterModel, GenericMetaFilter, + NotFoundError, to_db_json, ) from models.enums import AnalysisStatus from models.models.analysis import AnalysisInternal +from models.models.project import ProjectId @dataclasses.dataclass @@ -60,7 +61,6 @@ async def create_analysis( meta: Optional[Dict[str, Any]] = None, output: str = None, active: bool = True, - author: str = None, project: ProjectId = None, ) -> int: """ @@ -73,14 +73,11 @@ async def create_analysis( ('status', status.value), ('meta', to_db_json(meta or {})), ('output', output), - ('author', author or self.author), + ('audit_log_id', await self.audit_log_id()), ('project', project or self.project), ('active', active if active is not None else True), ] - if author is not None: - kv_pairs.append(('on_behalf_of', self.author)) - if status == AnalysisStatus.COMPLETED: kv_pairs.append(('timestamp_completed', datetime.datetime.utcnow())) @@ -110,11 +107,19 @@ async def add_sequencing_groups_to_analysis( """Add samples to an analysis (through the linked table)""" _query = """ INSERT INTO analysis_sequencing_group - (analysis_id, sequencing_group_id) - VALUES (:aid, :sid) + (analysis_id, sequencing_group_id, audit_log_id) + VALUES (:aid, :sid, :audit_log_id) """ - values = map(lambda sid: {'aid': analysis_id, 'sid': sid}, sequencing_group_ids) + audit_log_id = await self.audit_log_id() + values = map( + lambda sid: { + 'aid': analysis_id, + 'sid': sid, + 'audit_log_id': audit_log_id, + }, + sequencing_group_ids, + ) await self.connection.execute_many(_query, list(values)) async def find_sgs_in_joint_call_or_es_index_up_to_date( @@ -140,18 +145,17 @@ async def update_analysis( meta: Dict[str, Any] = None, active: bool = None, output: Optional[str] = None, - author: Optional[str] = None, ): """ Update the status of an analysis, set timestamp_completed if relevant """ fields: Dict[str, Any] = { - 'author': self.author or author, - 'on_behalf_of': self.author, 'analysis_id': analysis_id, + 'on_behalf_of': self.author, + 'audit_log_id': await self.audit_log_id(), } - setters = ['author = :author', 'on_behalf_of = :on_behalf_of'] + setters = ['audit_log_id = :audit_log_id', 'on_behalf_of = :on_behalf_of'] if status: setters.append('status = :status') fields['status'] = status.value @@ -497,8 +501,8 @@ async def get_analysis_runner_log( values['project_ids'] = project_ids if author: - wheres.append('author = :author') - values['author'] = author + wheres.append('audit_log_id = :audit_log_id') + values['audit_log_id'] = await self.audit_log_id() if output_dir: wheres.append('(output = :output OR output LIKE :output_like)') diff --git a/db/python/tables/assay.py b/db/python/tables/assay.py index 14653be01..95837e0f3 100644 --- a/db/python/tables/assay.py +++ b/db/python/tables/assay.py @@ -4,15 +4,17 @@ from collections import defaultdict from typing import Any -from db.python.connect import DbBase, NotFoundError, NoOpAenter -from db.python.tables.project import ProjectId +from db.python.tables.base import DbBase from db.python.utils import ( - to_db_json, - GenericFilterModel, GenericFilter, + GenericFilterModel, GenericMetaFilter, + NoOpAenter, + NotFoundError, + to_db_json, ) from models.models.assay import AssayInternal +from models.models.project import ProjectId REPLACEMENT_KEY_INVALID_CHARS = re.compile(r'[^\w\d_]') @@ -100,7 +102,7 @@ async def query( drow = dict(row) project_ids.add(drow.pop('project')) assay = AssayInternal.from_db(drow) - assay.external_ids = seq_eids.get(assay.id, {}) + assay.external_ids = seq_eids.get(assay.id, {}) if assay.id else {} assays.append(assay) return project_ids, assays @@ -135,7 +137,7 @@ async def get_assay_by_id(self, assay_id: int) -> tuple[ProjectId, AssayInternal return pjcts.pop(), assays.pop() async def get_assay_by_external_id( - self, external_sequence_id: str, project: int = None + self, external_sequence_id: str, project: ProjectId | None = None ) -> AssayInternal: """Get assay by EXTERNAL ID""" if not (project or self.project): @@ -212,7 +214,6 @@ async def insert_assay( external_ids: dict[str, str] | None, assay_type: str, meta: dict[str, Any] | None, - author: str | None = None, project: int | None = None, open_transaction: bool = True, ) -> int: @@ -241,8 +242,8 @@ async def insert_assay( _query = """\ INSERT INTO assay - (sample_id, meta, type, author) - VALUES (:sample_id, :meta, :type, :author) + (sample_id, meta, type, audit_log_id) + VALUES (:sample_id, :meta, :type, :audit_log_id) RETURNING id; """ @@ -255,7 +256,7 @@ async def insert_assay( 'sample_id': sample_id, 'meta': to_db_json(meta), 'type': assay_type, - 'author': author or self.author, + 'audit_log_id': await self.audit_log_id(), }, ) @@ -269,16 +270,17 @@ async def insert_assay( _eid_query = """ INSERT INTO assay_external_id - (project, assay_id, external_id, name, author) - VALUES (:project, :assay_id, :external_id, :name, :author); + (project, assay_id, external_id, name, audit_log_id) + VALUES (:project, :assay_id, :external_id, :name, :audit_log_id); """ + audit_log_id = await self.audit_log_id() eid_values = [ { 'project': project or self.project, 'assay_id': id_of_new_assay, 'external_id': eid, 'name': name.lower(), - 'author': author or self.author, + 'audit_log_id': audit_log_id, } for name, eid in external_ids.items() ] @@ -288,7 +290,7 @@ async def insert_assay( return id_of_new_assay async def insert_many_assays( - self, assays: list[AssayInternal], author=None, open_transaction: bool = True + self, assays: list[AssayInternal], open_transaction: bool = True ): """Insert many sequencing, returning no IDs""" with_function = self.connection.transaction if open_transaction else NoOpAenter @@ -303,13 +305,14 @@ async def insert_many_assays( sample_id=assay.sample_id, external_ids=assay.external_ids, meta=assay.meta, - author=author, assay_type=assay.type, open_transaction=False, ) ) return assay_ids + # endregion INSERTS + async def update_assay( self, assay_id: int, @@ -320,15 +323,18 @@ async def update_assay( sample_id: int | None = None, project: ProjectId | None = None, open_transaction: bool = True, - author=None, ): """Update an assay""" with_function = self.connection.transaction if open_transaction else NoOpAenter async with with_function(): - fields = {'assay_id': assay_id, 'author': author or self.author} + audit_log_id = await self.audit_log_id() + fields: dict[str, Any] = { + 'assay_id': assay_id, + 'audit_log_id': audit_log_id, + } - updaters = ['author = :author'] + updaters = ['audit_log_id = :audit_log_id'] if meta is not None: updaters.append('meta = JSON_MERGE_PATCH(COALESCE(meta, "{}"), :meta)') fields['meta'] = to_db_json(meta) @@ -363,7 +369,20 @@ async def update_assay( } if to_delete: + _assay_eid_update_before_delete = """ + UPDATE assay_external_id + SET audit_log_id = :audit_log_id + WHERE assay_id = :assay_id AND name in :names + """ _delete_query = 'DELETE FROM assay_external_id WHERE assay_id = :assay_id AND name in :names' + await self.connection.execute( + _assay_eid_update_before_delete, + { + 'assay_id': assay_id, + 'names': list(to_delete), + 'audit_log_id': audit_log_id, + }, + ) await self.connection.execute( _delete_query, {'assay_id': assay_id, 'names': list(to_delete)}, @@ -375,17 +394,18 @@ async def update_assay( ) _update_query = """\ - INSERT INTO assay_external_id (project, assay_id, external_id, name, author) - VALUES (:project, :assay_id, :external_id, :name, :author) - ON DUPLICATE KEY UPDATE external_id = :external_id, author = :author + INSERT INTO assay_external_id (project, assay_id, external_id, name, audit_log_id) + VALUES (:project, :assay_id, :external_id, :name, :audit_log_id) + ON DUPLICATE KEY UPDATE external_id = :external_id, audit_log_id = :audit_log_id """ + audit_log_id = await self.audit_log_id() values = [ { 'project': project, 'assay_id': assay_id, 'external_id': eid, 'name': name, - 'author': author or self.author, + 'audit_log_id': audit_log_id, } for name, eid in to_update.items() ] @@ -395,13 +415,13 @@ async def update_assay( async def get_assays_by( self, - assay_ids: list[int] = None, - sample_ids: list[int] = None, - assay_types: list[str] = None, - assay_meta: dict[str, Any] = None, - sample_meta: dict[str, Any] = None, - external_assay_ids: list[str] = None, - project_ids: list[int] = None, + assay_ids: list[int] | None = None, + sample_ids: list[int] | None = None, + assay_types: list[str] | None = None, + assay_meta: dict[str, Any] | None = None, + sample_meta: dict[str, Any] | None = None, + external_assay_ids: list[str] | None = None, + project_ids: list[int] | None = None, active: bool = True, ) -> tuple[list[ProjectId], list[AssayInternal]]: """Get sequences by some criteria""" diff --git a/db/python/tables/audit_log.py b/db/python/tables/audit_log.py new file mode 100644 index 000000000..f3f24038e --- /dev/null +++ b/db/python/tables/audit_log.py @@ -0,0 +1,74 @@ +from typing import Any + +from db.python.tables.base import DbBase +from db.python.utils import to_db_json +from models.models.audit_log import AuditLogInternal +from models.models.project import ProjectId + + +class AuditLogTable(DbBase): + """ + Capture Analysis table operations and queries + """ + + table_name = 'audit_log' + + async def get_projects_for_ids(self, audit_log_ids: list[int]) -> set[ProjectId]: + """Get project IDs for sampleIds (mostly for checking auth)""" + _query = """ + SELECT DISTINCT auth_project + FROM audit_log + WHERE id in :audit_log_ids + """ + if len(audit_log_ids) == 0: + raise ValueError('Received no audit log IDs') + rows = await self.connection.fetch_all(_query, {'audit_log_ids': audit_log_ids}) + return {r['project'] for r in rows} + + async def get_audit_logs_for_ids( + self, audit_log_ids: list[int] + ) -> list[AuditLogInternal]: + """Get project IDs for sampleIds (mostly for checking auth)""" + _query = """ + SELECT id, timestamp, author, on_behalf_of, ar_guid, comment, auth_project + FROM audit_log + WHERE id in :audit_log_ids + """ + if len(audit_log_ids) == 0: + raise ValueError('Received no audit log IDs') + rows = await self.connection.fetch_all(_query, {'audit_log_ids': audit_log_ids}) + return [AuditLogInternal.from_db(dict(r)) for r in rows] + + async def create_audit_log( + self, + author: str, + on_behalf_of: str | None, + ar_guid: str | None, + comment: str | None, + project: ProjectId | None, + meta: dict[str, Any] | None = None, + ) -> int: + """ + Create a new audit log entry + """ + + _query = """ + INSERT INTO audit_log + (author, on_behalf_of, ar_guid, comment, auth_project, meta) + VALUES + (:author, :on_behalf_of, :ar_guid, :comment, :project, :meta) + RETURNING id + """ + audit_log_id = await self.connection.fetch_val( + _query, + { + 'author': author, + 'on_behalf_of': on_behalf_of, + 'ar_guid': ar_guid, + 'comment': comment, + 'project': project, + 'meta': to_db_json(meta or {}), + }, + ) + + return audit_log_id diff --git a/db/python/tables/base.py b/db/python/tables/base.py new file mode 100644 index 000000000..82815d10b --- /dev/null +++ b/db/python/tables/base.py @@ -0,0 +1,42 @@ +import databases + +from db.python.connect import Connection +from db.python.utils import InternalError + + +class DbBase: + """Base class for table subclasses""" + + def __init__(self, connection: Connection): + if connection is None: + raise InternalError( + f'No connection was provided to the table {self.__class__.__name__!r}' + ) + if not isinstance(connection, Connection): + raise InternalError( + f'Expected connection type Connection, received {type(connection)}, ' + f'did you mean to call self._connection?' + ) + + self._connection = connection + self.connection: databases.Database = connection.connection + self.author = connection.author + self.project = connection.project + + if self.author is None: + raise InternalError(f'Must provide author to {self.__class__.__name__}') + + async def audit_log_id(self): + """ + Get audit_log ID (or fail otherwise) + """ + return await self._connection.audit_log_id() + + # piped from the connection + + @staticmethod + def escape_like_term(query: str): + """ + Escape meaningful keys when using LIKE with a user supplied input + """ + return query.replace('%', '\\%').replace('_', '\\_') diff --git a/db/python/tables/family.py b/db/python/tables/family.py index 853523b6f..18b13dd72 100644 --- a/db/python/tables/family.py +++ b/db/python/tables/family.py @@ -1,9 +1,10 @@ from collections import defaultdict -from typing import List, Optional, Set, Any, Dict +from typing import Any, Dict, List, Optional, Set -from db.python.connect import DbBase, NotFoundError -from db.python.tables.project import ProjectId +from db.python.tables.base import DbBase +from db.python.utils import NotFoundError from models.models.family import FamilyInternal +from models.models.project import ProjectId class FamilyTable(DbBase): @@ -47,7 +48,7 @@ async def get_families( JOIN family_participant ON family.id = family_participant.family_id """ - where.append(f'participant_id IN :pids') + where.append('participant_id IN :pids') values['pids'] = participant_ids if project or self.project: @@ -175,13 +176,12 @@ async def get_family_external_ids_by_participant_ids( async def update_family( self, id_: int, - external_id: str = None, - description: str = None, - coded_phenotype: str = None, - author: str = None, + external_id: str | None = None, + description: str | None = None, + coded_phenotype: str | None = None, ) -> bool: """Update values for a family""" - values: Dict[str, Any] = {'author': author or self.author} + values: Dict[str, Any] = {'audit_log_id': await self.audit_log_id()} if external_id: values['external_id'] = external_id if description: @@ -203,7 +203,6 @@ async def create_family( external_id: str, description: Optional[str], coded_phenotype: Optional[str], - author: str = None, project: ProjectId = None, ) -> int: """ @@ -213,7 +212,7 @@ async def create_family( 'external_id': external_id, 'description': description, 'coded_phenotype': coded_phenotype, - 'author': author or self.author, + 'audit_log_id': await self.audit_log_id(), 'project': project or self.project, } keys = list(updater.keys()) @@ -235,7 +234,6 @@ async def insert_or_update_multiple_families( descriptions: List[str], coded_phenotypes: List[Optional[str]], project: int = None, - author: str = None, ): """Upsert""" updater = [ @@ -243,7 +241,7 @@ async def insert_or_update_multiple_families( 'external_id': eid, 'description': descr, 'coded_phenotype': cph, - 'author': author or self.author, + 'audit_log_id': await self.audit_log_id(), 'project': project or self.project, } for eid, descr, cph in zip(external_ids, descriptions, coded_phenotypes) diff --git a/db/python/tables/family_participant.py b/db/python/tables/family_participant.py index 8f2c730d6..8aab115a8 100644 --- a/db/python/tables/family_participant.py +++ b/db/python/tables/family_participant.py @@ -1,9 +1,9 @@ from collections import defaultdict -from typing import Tuple, List, Dict, Optional, Set, Any +from typing import Any, Dict, List, Optional, Set, Tuple -from db.python.connect import DbBase -from db.python.tables.project import ProjectId +from db.python.tables.base import DbBase from models.models.family import PedRowInternal +from models.models.project import ProjectId class FamilyParticipantTable(DbBase): @@ -21,7 +21,6 @@ async def create_row( maternal_id: int, affected: int, notes: str = None, - author=None, ) -> Tuple[int, int]: """ Create a new sample, and add it to database @@ -33,7 +32,7 @@ async def create_row( 'maternal_participant_id': maternal_id, 'affected': affected, 'notes': notes, - 'author': author or self.author, + 'audit_log_id': await self.audit_log_id(), } keys = list(updater.keys()) str_keys = ', '.join(keys) @@ -52,7 +51,6 @@ async def create_row( async def create_rows( self, rows: list[PedRowInternal], - author=None, ): """ Create many rows, dictionaries must have keys: @@ -76,7 +74,7 @@ async def create_rows( 'maternal_participant_id': row.maternal_id, 'affected': row.affected, 'notes': row.notes, - 'author': author or self.author, + 'audit_log_id': await self.audit_log_id(), } remapped_ds_by_keys[tuple(sorted(d.keys()))].append(d) @@ -225,12 +223,27 @@ async def delete_family_participant_row(self, family_id: int, participant_id: in if not participant_id or not family_id: return False + _update_before_delete = """ + UPDATE family_participant + SET audit_log_id = :audit_log_id + WHERE family_id = :family_id AND participant_id = :participant_id + """ + _query = """ -DELETE FROM family_participant -WHERE participant_id = :participant_id -AND family_id = :family_id + DELETE FROM family_participant + WHERE participant_id = :participant_id + AND family_id = :family_id """ + await self.connection.execute( + _update_before_delete, + { + 'family_id': family_id, + 'participant_id': participant_id, + 'audit_log_id': await self.audit_log_id(), + }, + ) + await self.connection.execute( _query, {'family_id': family_id, 'participant_id': participant_id} ) diff --git a/db/python/tables/participant.py b/db/python/tables/participant.py index 0b0b8eb9c..6f326f9cf 100644 --- a/db/python/tables/participant.py +++ b/db/python/tables/participant.py @@ -1,9 +1,10 @@ from collections import defaultdict from typing import Any -from db.python.connect import DbBase, NotFoundError -from db.python.utils import ProjectId, to_db_json +from db.python.tables.base import DbBase +from db.python.utils import NotFoundError, to_db_json from models.models.participant import ParticipantInternal +from models.models.project import ProjectId class ParticipantTable(DbBase): @@ -20,6 +21,7 @@ class ParticipantTable(DbBase): 'karyotype', 'meta', 'project', + 'audit_log_id', ] ) @@ -75,7 +77,6 @@ async def create_participant( reported_gender: str | None, karyotype: str | None, meta: dict | None, - author: str = None, project: ProjectId = None, ) -> int: """ @@ -84,9 +85,11 @@ async def create_participant( if not (project or self.project): raise ValueError('Must provide project to create participant') - _query = f""" -INSERT INTO participant (external_id, reported_sex, reported_gender, karyotype, meta, author, project) -VALUES (:external_id, :reported_sex, :reported_gender, :karyotype, :meta, :author, :project) + _query = """ +INSERT INTO participant + (external_id, reported_sex, reported_gender, karyotype, meta, audit_log_id, project) +VALUES + (:external_id, :reported_sex, :reported_gender, :karyotype, :meta, :audit_log_id, :project) RETURNING id """ @@ -98,7 +101,7 @@ async def create_participant( 'reported_gender': reported_gender, 'karyotype': karyotype, 'meta': to_db_json(meta or {}), - 'author': author or self.author, + 'audit_log_id': await self.audit_log_id(), 'project': project or self.project, }, ) @@ -110,18 +113,17 @@ async def update_participants( reported_genders: list[str] | None, karyotypes: list[str] | None, metas: list[dict] | None, - author=None, ): """ Update many participants, expects that all lists contain the same number of values. You can't update selective fields on selective samples, if you provide metas, this function will update EVERY participant with the provided meta values. """ - _author = author or self.author - updaters = ['author = :author'] + updaters = ['audit_log_id = :audit_log_id'] + audit_log_id = await self.audit_log_id() values: dict[str, list[Any]] = { 'pid': participant_ids, - 'author': [_author] * len(participant_ids), + 'audit_log_id': [audit_log_id] * len(participant_ids), } if reported_sexes: updaters.append('reported_sex = :reported_sex') @@ -154,13 +156,12 @@ async def update_participant( reported_gender: str | None, karyotype: str | None, meta: dict | None, - author=None, ): """ Update participant """ - updaters = ['author = :author'] - fields = {'pid': participant_id, 'author': author or self.author} + updaters = ['audit_log_id = :audit_log_id'] + fields = {'pid': participant_id, 'audit_log_id': await self.audit_log_id()} if external_id: updaters.append('external_id = :external_id') @@ -248,7 +249,7 @@ async def get_participants_by_families( self, family_ids: list[int] ) -> tuple[set[ProjectId], dict[int, list[ParticipantInternal]]]: """Get list of participants keyed by families, duplicates results""" - _query = f""" + _query = """ SELECT project, fp.family_id, p.id, p.external_id, p.reported_sex, p.reported_gender, p.karyotype, p.meta FROM participant p INNER JOIN family_participant fp ON fp.participant_id = p.id @@ -271,10 +272,11 @@ async def update_many_participant_external_ids( """Update many participant external_ids through the {internal: external} map""" _query = """ UPDATE participant - SET external_id = :external_id + SET external_id = :external_id, audit_log_id = :audit_log_id WHERE id = :participant_id""" + audit_log_id = await self.audit_log_id() mapped_values = [ - {'participant_id': k, 'external_id': v} + {'participant_id': k, 'external_id': v, 'audit_log_id': audit_log_id} for k, v in internal_to_external_id.items() ] await self.connection.execute_many(_query, mapped_values) diff --git a/db/python/tables/participant_phenotype.py b/db/python/tables/participant_phenotype.py index 556617209..ea010588d 100644 --- a/db/python/tables/participant_phenotype.py +++ b/db/python/tables/participant_phenotype.py @@ -1,9 +1,8 @@ -from typing import List, Tuple, Any, Dict - import json from collections import defaultdict +from typing import Any, Dict, List, Tuple -from db.python.connect import DbBase +from db.python.tables.base import DbBase class ParticipantPhenotypeTable(DbBase): @@ -13,27 +12,28 @@ class ParticipantPhenotypeTable(DbBase): table_name = 'participant_phenotype' - async def add_key_value_rows( - self, rows: List[Tuple[int, str, Any]], author: str = None - ) -> None: + async def add_key_value_rows(self, rows: List[Tuple[int, str, Any]]) -> None: """ Create a new sample, and add it to database """ if not rows: return None - _query = f""" -INSERT INTO participant_phenotypes (participant_id, description, value, author, hpo_term) -VALUES (:participant_id, :description, :value, :author, 'DESCRIPTION') + _query = """ +INSERT INTO participant_phenotypes + (participant_id, description, value, audit_log_id, hpo_term) +VALUES + (:participant_id, :description, :value, :audit_log_id, 'DESCRIPTION') ON DUPLICATE KEY UPDATE - description=:description, value=:value, author=:author + description=:description, value=:value, audit_log_id=:audit_log_id """ + audit_log_id = await self.audit_log_id() formatted_rows = [ { 'participant_id': r[0], 'description': r[1], 'value': json.dumps(r[2]), - 'author': author or self.author, + 'audit_log_id': audit_log_id, } for r in rows ] diff --git a/db/python/tables/project.py b/db/python/tables/project.py index 5ff51010f..b3f0bec55 100644 --- a/db/python/tables/project.py +++ b/db/python/tables/project.py @@ -5,16 +5,16 @@ from databases import Database from api.settings import is_all_access +from db.python.connect import Connection, SMConnections from db.python.utils import ( Forbidden, InternalError, NoProjectAccess, NotFoundError, - ProjectId, get_logger, to_db_json, ) -from models.models.project import Project +from models.models.project import Project, ProjectId logger = get_logger() @@ -38,14 +38,59 @@ def get_project_group_name(project_name: str, readonly: bool) -> str: return f'{project_name}-read' return f'{project_name}-write' - def __init__(self, connection: Database, allow_full_access=None): - if not isinstance(connection, Database): + def __init__( + self, + connection: Connection | None, + allow_full_access: bool | None = None, + database_connection: Database | None = None, + ): + self._connection = connection + if not database_connection and not connection: raise ValueError( - f'Invalid type connection, expected Database, got {type(connection)}, ' - 'did you forget to call connection.connection?' + 'Must call project permissions table with either a direct ' + 'database_connection or a fully formed connection' ) - self.connection: Database = connection - self.gtable = GroupTable(connection, allow_full_access=allow_full_access) + self.connection: Database = database_connection or connection.connection + self.gtable = GroupTable(self.connection, allow_full_access=allow_full_access) + + @staticmethod + async def get_project_connection( + *, + author: str, + project_name: str, + readonly: bool, + ar_guid: str, + on_behalf_of: str | None = None, + ): + """Get a db connection from a project and user""" + # maybe it makes sense to perform permission checks here too + logger.debug(f'Authenticate connection to {project_name} with {author!r}') + + conn = await SMConnections.get_made_connection() + pt = ProjectPermissionsTable(connection=None, database_connection=conn) + + project = await pt.get_and_check_access_to_project_for_name( + user=author, project_name=project_name, readonly=readonly + ) + + return Connection( + connection=conn, + author=author, + project=project.id, + readonly=readonly, + on_behalf_of=on_behalf_of, + ar_guid=ar_guid, + ) + + async def audit_log_id(self): + """ + Generate (or return) a audit_log_id by inserting a row into the database + """ + if not self._connection: + raise ValueError( + 'Cannot call audit_log_id without a fully formed connection' + ) + return await self._connection.audit_log_id() # region UNPROTECTED_GETS @@ -310,21 +355,24 @@ async def create_project( await self.check_project_creator_permissions(author) async with self.connection.transaction(): + audit_log_id = await self.audit_log_id() read_group_id = await self.gtable.create_group( - self.get_project_group_name(project_name, readonly=True) + self.get_project_group_name(project_name, readonly=True), + audit_log_id=audit_log_id, ) write_group_id = await self.gtable.create_group( - self.get_project_group_name(project_name, readonly=False) + self.get_project_group_name(project_name, readonly=False), + audit_log_id=audit_log_id, ) _query = """\ - INSERT INTO project (name, dataset, author, read_group_id, write_group_id) - VALUES (:name, :dataset, :author, :read_group_id, :write_group_id) + INSERT INTO project (name, dataset, audit_log_id, read_group_id, write_group_id) + VALUES (:name, :dataset, :audit_log_id, :read_group_id, :write_group_id) RETURNING ID""" values = { 'name': project_name, 'dataset': dataset_name, - 'author': author, + 'audit_log_id': await self.audit_log_id(), 'read_group_id': read_group_id, 'write_group_id': write_group_id, } @@ -342,9 +390,12 @@ async def update_project(self, project_name: str, update: dict, author: str): meta = update.get('meta') - fields: Dict[str, Any] = {'author': author, 'name': project_name} + fields: Dict[str, Any] = { + 'audit_log_id': await self.audit_log_id(), + 'name': project_name, + } - setters = ['author = :author'] + setters = ['audit_log_id = :audit_log_id'] if meta is not None and len(meta) > 0: fields['meta'] = to_db_json(meta) @@ -443,7 +494,9 @@ async def set_group_members(self, group_name: str, members: list[str], author: s f'User {author} does not have permission to add members to group {group_name}' ) group_id = await self.gtable.get_group_name_from_id(group_name) - await self.gtable.set_group_members(group_id, members, author=author) + await self.gtable.set_group_members( + group_id, members, audit_log_id=await self.audit_log_id() + ) # endregion CREATE / UPDATE @@ -573,16 +626,20 @@ async def check_which_groups_member_has( ) return set(r['gid'] for r in results) - async def create_group(self, name: str) -> int: + async def create_group(self, name: str, audit_log_id: int) -> int: """Create a new group""" _query = """ - INSERT INTO `group` (name) - VALUES (:name) + INSERT INTO `group` (name, audit_log_id) + VALUES (:name, :audit_log_id) RETURNING id """ - return await self.connection.fetch_val(_query, {'name': name}) + return await self.connection.fetch_val( + _query, {'name': name, 'audit_log_id': audit_log_id} + ) - async def set_group_members(self, group_id: int, members: list[str], author: str): + async def set_group_members( + self, group_id: int, members: list[str], audit_log_id: int + ): """ Set group members for a group (by id) """ @@ -595,11 +652,15 @@ async def set_group_members(self, group_id: int, members: list[str], author: str ) await self.connection.execute_many( """ - INSERT INTO group_member (group_id, member, author) - VALUES (:group_id, :member, :author) + INSERT INTO group_member (group_id, member, audit_log_id) + VALUES (:group_id, :member, :audit_log_id) """, [ - {'group_id': group_id, 'member': member, 'author': author} + { + 'group_id': group_id, + 'member': member, + 'audit_log_id': audit_log_id, + } for member in members ], ) diff --git a/db/python/tables/sample.py b/db/python/tables/sample.py index 0f15645a2..3457a2333 100644 --- a/db/python/tables/sample.py +++ b/db/python/tables/sample.py @@ -1,16 +1,17 @@ import asyncio -from datetime import date -from typing import Iterable, Any import dataclasses +from datetime import date +from typing import Any, Iterable -from db.python.connect import DbBase, NotFoundError +from db.python.tables.base import DbBase from db.python.utils import ( - to_db_json, - GenericFilterModel, GenericFilter, + GenericFilterModel, GenericMetaFilter, + NotFoundError, + to_db_json, ) -from db.python.tables.project import ProjectId +from models.models.project import ProjectId from models.models.sample import SampleInternal, sample_id_format @@ -47,6 +48,7 @@ class SampleTable(DbBase): 'active', 'type', 'project', + 'audit_log_id', ] # region GETS @@ -139,7 +141,6 @@ async def insert_sample( active: bool, meta: dict | None, participant_id: int | None, - author=None, project=None, ) -> int: """ @@ -152,7 +153,7 @@ async def insert_sample( ('meta', to_db_json(meta or {})), ('type', sample_type), ('active', active), - ('author', author or self.author), + ('audit_log_id', await self.audit_log_id()), ('project', project or self.project), ] @@ -179,16 +180,13 @@ async def update_sample( participant_id: int | None, external_id: str | None, type_: str | None, - author: str = None, active: bool = None, ): """Update a single sample""" - values: dict[str, Any] = { - 'author': author or self.author, - } + values: dict[str, Any] = {'audit_log_id': await self.audit_log_id()} fields = [ - 'author = :author', + 'audit_log_id = :audit_log_id', ] if participant_id: values['participant_id'] = participant_id @@ -220,7 +218,6 @@ async def merge_samples( self, id_keep: int = None, id_merge: int = None, - author: str = None, ): """Merge two samples together""" sid_merge = sample_id_format(id_merge) @@ -261,34 +258,43 @@ def dict_merge(meta1, meta2): meta_original.get('merged_from'), sid_merge ) meta: dict[str, Any] = dict_merge(meta_original, sample_merge.meta) - + audit_log_id = await self.audit_log_id() values: dict[str, Any] = { 'sample': { 'id': id_keep, - 'author': author or self.author, + 'audit_log_id': audit_log_id, 'meta': to_db_json(meta), }, - 'ids': {'id_keep': id_keep, 'id_merge': id_merge}, + 'ids': { + 'id_keep': id_keep, + 'id_merge': id_merge, + 'audit_log_id': audit_log_id, + }, } _query = """ UPDATE sample - SET author = :author, + SET audit_log_id = :audit_log_id, meta = :meta WHERE id = :id """ - _query_seqs = f""" - UPDATE sample_sequencing - SET sample_id = :id_keep + _query_seqs = """ + UPDATE assay + SET sample_id = :id_keep, audit_log_id = :audit_log_id WHERE sample_id = :id_merge """ # TODO: merge sequencing groups I guess? - _query_analyses = f""" - UPDATE analysis_sample - SET sample_id = :id_keep + _query_analyses = """ + UPDATE analysis_sequencing_group + SET sample_id = :id_keep, audit_log_id = :audit_log_id WHERE sample_id = :id_merge """ - _del_sample = f""" + _query_update_sample_with_audit_log = """ + UPDATE sample + SET audit_log_id = :audit_log_id + WHERE id = :id_merge + """ + _del_sample = """ DELETE FROM sample WHERE id = :id_merge """ @@ -297,11 +303,16 @@ def dict_merge(meta1, meta2): await self.connection.execute(_query, {**values['sample']}) await self.connection.execute(_query_seqs, {**values['ids']}) await self.connection.execute(_query_analyses, {**values['ids']}) - await self.connection.execute(_del_sample, {'id_merge': id_merge}) + await self.connection.execute( + _query_update_sample_with_audit_log, + {'id_merge': id_merge, 'audit_log_id': audit_log_id}, + ) + await self.connection.execute( + _del_sample, {'id_merge': id_merge, 'audit_log_id': audit_log_id} + ) project, new_sample = await self.get_sample_by_id(id_keep) new_sample.project = project - new_sample.author = author or self.author return new_sample @@ -312,9 +323,15 @@ async def update_many_participant_ids( Update participant IDs for many samples Expected len(ids) == len(participant_ids) """ - _query = 'UPDATE sample SET participant_id=:participant_id WHERE id = :id' + _query = """ + UPDATE sample + SET participant_id=:participant_id, audit_log_id = :audit_log_id + WHERE id = :id + """ + audit_log_id = await self.audit_log_id() values = [ - {'id': i, 'participant_id': pid} for i, pid in zip(ids, participant_ids) + {'id': i, 'participant_id': pid, 'audit_log_id': audit_log_id} + for i, pid in zip(ids, participant_ids) ] await self.connection.execute_many(_query, values) @@ -420,6 +437,7 @@ async def get_history_of_sample(self, id_: int): 'type', 'project', 'author', + 'audit_log_id', ] keys_str = ', '.join(keys) _query = f'SELECT {keys_str} FROM sample FOR SYSTEM_TIME ALL WHERE id = :id' diff --git a/db/python/tables/sequencing_group.py b/db/python/tables/sequencing_group.py index 1be71e29a..1b82da962 100644 --- a/db/python/tables/sequencing_group.py +++ b/db/python/tables/sequencing_group.py @@ -4,14 +4,16 @@ from datetime import date from typing import Any -from db.python.connect import DbBase, NoOpAenter, NotFoundError +from db.python.tables.base import DbBase from db.python.utils import ( GenericFilter, GenericFilterModel, GenericMetaFilter, - ProjectId, + NoOpAenter, + NotFoundError, to_db_json, ) +from models.models.project import ProjectId from models.models.sequencing_group import ( SequencingGroupInternal, SequencingGroupInternalId, @@ -276,7 +278,6 @@ async def create_sequencing_group( platform: str, assay_ids: list[int], meta: dict = None, - author: str = None, open_transaction=True, ) -> int: """Create sequence group""" @@ -307,16 +308,17 @@ async def create_sequencing_group( _query = """ INSERT INTO sequencing_group - (sample_id, type, technology, platform, meta, author, archived) - VALUES (:sample_id, :type, :technology, :platform, :meta, :author, false) + (sample_id, type, technology, platform, meta, audit_log_id, archived) + VALUES + (:sample_id, :type, :technology, :platform, :meta, :audit_log_id, false) RETURNING id; """ _seqg_linker_query = """ INSERT INTO sequencing_group_assay - (sequencing_group_id, assay_id, author) + (sequencing_group_id, assay_id, audit_log_id) VALUES - (:seqgroup, :assayid, :author) + (:seqgroup, :assayid, :audit_log_id) """ values = { @@ -339,13 +341,13 @@ async def create_sequencing_group( id_of_seq_group = await self.connection.fetch_val( _query, - {**values, 'author': author or self.author}, + {**values, 'audit_log_id': await self.audit_log_id()}, ) assay_id_insert_values = [ { 'seqgroup': id_of_seq_group, 'assayid': s, - 'author': author or self.author, + 'audit_log_id': await self.audit_log_id(), } for s in assay_ids ] @@ -361,9 +363,10 @@ async def update_sequencing_group( """ Update meta / platform on sequencing_group """ - updaters = [] + updaters = ['audit_log_id = :audit_log_id'] values: dict[str, Any] = { 'seqgid': sequencing_group_id, + 'audit_log_id': await self.audit_log_id(), } if meta: @@ -388,21 +391,28 @@ async def archive_sequencing_groups(self, sequencing_group_id: list[int]): """ _query = """ UPDATE sequencing_group - SET archived = 1, author = :author + SET archived = 1, audit_log_id = :audit_log_id WHERE id = :sequencing_group_id; """ # do this so we can reuse the sequencing_group_ids _external_id_query = """ UPDATE sequencing_group_external_id - SET nullIfInactive = NULL + SET nullIfInactive = NULL, audit_log_id = :audit_log_id WHERE sequencing_group_id = :sequencing_group_id; """ await self.connection.execute( - _query, {'sequencing_group_id': sequencing_group_id, 'author': self.author} + _query, + { + 'sequencing_group_id': sequencing_group_id, + 'audit_log_id': await self.audit_log_id(), + }, ) await self.connection.execute( _external_id_query, - {'sequencing_group_id': sequencing_group_id}, + { + 'sequencing_group_id': sequencing_group_id, + 'audit_log_id': await self.audit_log_id(), + }, ) async def get_type_numbers_for_project(self, project) -> dict[str, int]: diff --git a/db/python/utils.py b/db/python/utils.py index 847ad69c1..4431ca587 100644 --- a/db/python/utils.py +++ b/db/python/utils.py @@ -7,7 +7,6 @@ from typing import Any, Generic, Sequence, TypeVar T = TypeVar('T') -ProjectId = int levels_map = {'DEBUG': logging.DEBUG, 'INFO': logging.INFO, 'WARNING': logging.WARNING} diff --git a/models/base.py b/models/base.py index 389c38a56..2b728e2af 100644 --- a/models/base.py +++ b/models/base.py @@ -7,3 +7,15 @@ class SMBase(BaseModel): """Base object for all models""" + + +def parse_sql_bool(val: str | int | bytes) -> bool | None: + """Parse a string from a sql bool""" + if val is None: + return None + if isinstance(val, int): + return bool(val) + if isinstance(val, bytes): + return bool(ord(val)) + + raise ValueError(f'Unknown type for active: {type(val)}') diff --git a/models/models/__init__.py b/models/models/__init__.py index aa7cf5324..4e52b2bd2 100644 --- a/models/models/__init__.py +++ b/models/models/__init__.py @@ -1,3 +1,4 @@ +from models.base import parse_sql_bool from models.models.analysis import ( Analysis, AnalysisInternal, @@ -8,11 +9,15 @@ ProportionalDateTemporalMethod, SequencingGroupSizeModel, ) -from models.models.assay import ( - Assay, - AssayInternal, - AssayUpsert, - AssayUpsertInternal, +from models.models.assay import Assay, AssayInternal, AssayUpsert, AssayUpsertInternal +from models.models.audit_log import AuditLogId, AuditLogInternal +from models.models.billing import ( + BillingColumn, + BillingCostBudgetRecord, + BillingCostDetailsRecord, + BillingRowRecord, + BillingTotalCostQueryModel, + BillingTotalCostRecord, ) from models.models.family import ( Family, @@ -29,7 +34,7 @@ ParticipantUpsert, ParticipantUpsertInternal, ) -from models.models.project import Project +from models.models.project import Project, ProjectId from models.models.sample import ( NestedSample, NestedSampleInternal, @@ -62,11 +67,3 @@ ProjectSummaryInternal, WebProject, ) -from models.models.billing import ( - BillingRowRecord, - BillingTotalCostRecord, - BillingTotalCostQueryModel, - BillingColumn, - BillingCostBudgetRecord, - BillingCostDetailsRecord, -) diff --git a/models/models/audit_log.py b/models/models/audit_log.py new file mode 100644 index 000000000..3a2f2e69c --- /dev/null +++ b/models/models/audit_log.py @@ -0,0 +1,25 @@ +import datetime + +from models.base import SMBase +from models.models.project import ProjectId + +AuditLogId = int + + +class AuditLogInternal(SMBase): + """ + Model for audit_log + """ + + id: AuditLogId + timestamp: datetime.datetime + author: str + auth_project: ProjectId + on_behalf_of: str | None + ar_guid: str | None + comment: str | None + + @staticmethod + def from_db(d: dict): + """Take DB mapping object, and return SampleSequencing""" + return AuditLogInternal(**d) diff --git a/models/models/participant.py b/models/models/participant.py index ab843e343..de0dc6ef5 100644 --- a/models/models/participant.py +++ b/models/models/participant.py @@ -1,8 +1,8 @@ import json -from db.python.utils import ProjectId from models.base import OpenApiGenNoneType, SMBase from models.models.family import FamilySimple, FamilySimpleInternal +from models.models.project import ProjectId from models.models.sample import ( NestedSample, NestedSampleInternal, @@ -22,6 +22,8 @@ class ParticipantInternal(SMBase): karyotype: str | None = None meta: dict + audit_log_id: int | None = None + @classmethod def from_db(cls, data: dict): """Convert from db keys, mainly converting parsing meta""" diff --git a/models/models/project.py b/models/models/project.py index 4372a086d..9ca19542f 100644 --- a/models/models/project.py +++ b/models/models/project.py @@ -3,11 +3,13 @@ from models.base import SMBase +ProjectId = int + class Project(SMBase): """Row for project in 'project' table""" - id: Optional[int] = None + id: Optional[ProjectId] = None name: Optional[str] = None dataset: Optional[str] = None meta: Optional[dict] = None diff --git a/models/models/sample.py b/models/models/sample.py index a41c06463..5f183ff2e 100644 --- a/models/models/sample.py +++ b/models/models/sample.py @@ -1,6 +1,6 @@ import json -from models.base import OpenApiGenNoneType, SMBase +from models.base import OpenApiGenNoneType, SMBase, parse_sql_bool from models.models.assay import Assay, AssayInternal, AssayUpsert, AssayUpsertInternal from models.models.sequencing_group import ( NestedSequencingGroup, @@ -31,9 +31,8 @@ def from_db(d: dict): _id = d.pop('id', None) type_ = d.pop('type', None) meta = d.pop('meta', None) - active = d.pop('active', None) - if active is not None: - active = bool(ord(active)) + active = parse_sql_bool(d.pop('active', None)) + if meta: if isinstance(meta, bytes): meta = meta.decode() diff --git a/models/models/search.py b/models/models/search.py index cceb68a92..08b3a6d30 100644 --- a/models/models/search.py +++ b/models/models/search.py @@ -1,7 +1,7 @@ -from db.python.utils import ProjectId from models.base import SMBase from models.enums.search import SearchResponseType from models.enums.web import MetaSearchEntityPrefix +from models.models.project import ProjectId class SearchResponseData(SMBase): diff --git a/openapi-templates/api_client.mustache b/openapi-templates/api_client.mustache index b083d401d..d78bf1909 100644 --- a/openapi-templates/api_client.mustache +++ b/openapi-templates/api_client.mustache @@ -11,6 +11,8 @@ import typing from urllib.parse import quote from urllib3.fields import RequestField +from cpg_utils.config import try_get_ar_guid + {{#tornado}} import tornado.gen {{/tornado}} @@ -67,6 +69,9 @@ class ApiClient(object): self.rest_client = rest.RESTClientObject(configuration) self.default_headers = {} + ar_guid = try_get_ar_guid() + if ar_guid: + self.default_headers['sm-ar-guid'] = ar_guid if header_name is not None: self.default_headers[header_name] = header_value self.cookie = cookie diff --git a/test/test_analysis.py b/test/test_analysis.py index 1676c6f6a..be7a15e90 100644 --- a/test/test_analysis.py +++ b/test/test_analysis.py @@ -152,7 +152,7 @@ async def test_get_analysis(self): project=1, meta={}, active=True, - author='testuser', + author=None, ) ] diff --git a/test/test_assay.py b/test/test_assay.py index cde5f242a..d53e79eb8 100644 --- a/test/test_assay.py +++ b/test/test_assay.py @@ -2,7 +2,7 @@ from pymysql.err import IntegrityError -from db.python.connect import NotFoundError +from db.python.utils import NotFoundError from db.python.enum_tables import AssayTypeTable from db.python.layers.assay import AssayLayer from db.python.layers.sample import SampleLayer diff --git a/test/test_audit_log.py b/test/test_audit_log.py new file mode 100644 index 000000000..56b758585 --- /dev/null +++ b/test/test_audit_log.py @@ -0,0 +1,30 @@ +from test.testbase import DbIsolatedTest, run_as_sync + +from db.python.layers.sample import SampleLayer +from models.models.sample import SampleUpsertInternal + + +class TestChangelog(DbIsolatedTest): + """Test audit_log""" + + @run_as_sync + async def test_insert_sample(self): + """ + Test inserting a sample, and check that the audit_log_id reflects the current + change + """ + slayer = SampleLayer(self.connection) + sample = await slayer.upsert_sample( + SampleUpsertInternal( + external_id='Test01', + type='blood', + active=True, + meta={'meta': 'meta ;)'}, + ) + ) + + sample_cl_id = await self.connection.connection.fetch_val( + 'SELECT audit_log_id FROM sample WHERE id = :sid', {'sid': sample.id} + ) + + self.assertEqual(await self.audit_log_id(), sample_cl_id) diff --git a/test/test_project_groups.py b/test/test_project_groups.py index 06667512d..d02e884f1 100644 --- a/test/test_project_groups.py +++ b/test/test_project_groups.py @@ -21,7 +21,7 @@ async def setUp(self): super().setUp() # specifically required to test permissions - self.pttable = ProjectPermissionsTable(self.connection.connection, False) + self.pttable = ProjectPermissionsTable(self.connection, False) async def _add_group_member_direct(self, group_name): """ @@ -33,13 +33,13 @@ async def _add_group_member_direct(self, group_name): ) await self.connection.connection.execute( """ - INSERT INTO group_member (group_id, member, author) - VALUES (:group_id, :member, :author); + INSERT INTO group_member (group_id, member, audit_log_id) + VALUES (:group_id, :member, :audit_log_id); """, { 'group_id': members_admin_group, 'member': self.author, - 'author': self.author, + 'audit_log_id': await self.audit_log_id(), }, ) @@ -73,7 +73,7 @@ async def test_group_set_members_succeeded(self): await self._add_group_member_direct(GROUP_NAME_MEMBERS_ADMIN) g = str(uuid.uuid4()) - await self.pttable.gtable.create_group(g) + await self.pttable.gtable.create_group(g, await self.audit_log_id()) self.assertFalse( await self.pttable.gtable.check_if_member_in_group_name(g, 'user1') @@ -99,7 +99,7 @@ async def test_check_which_groups_member_is_missing(self): await self._add_group_member_direct(GROUP_NAME_MEMBERS_ADMIN) group = str(uuid.uuid4()) - gid = await self.pttable.gtable.create_group(group) + gid = await self.pttable.gtable.create_group(group, await self.audit_log_id()) present_gids = await self.pttable.gtable.check_which_groups_member_has( {gid}, self.author ) @@ -113,10 +113,12 @@ async def test_check_which_groups_member_is_missing_none(self): await self._add_group_member_direct(GROUP_NAME_MEMBERS_ADMIN) group = str(uuid.uuid4()) - gid = await self.pttable.gtable.create_group(group) - await self.pttable.gtable.set_group_members(gid, [self.author], self.author) + gid = await self.pttable.gtable.create_group(group, await self.audit_log_id()) + await self.pttable.gtable.set_group_members( + gid, [self.author], audit_log_id=await self.audit_log_id() + ) present_gids = await self.pttable.gtable.check_which_groups_member_has( - {gid}, self.author + group_ids={gid}, member=self.author ) missing_gids = {gid} - present_gids @@ -160,7 +162,7 @@ async def setUp(self): super().setUp() # specifically required to test permissions - self.pttable = ProjectPermissionsTable(self.connection.connection, False) + self.pttable = ProjectPermissionsTable(self.connection, False) async def _add_group_member_direct(self, group_name): """ @@ -172,13 +174,13 @@ async def _add_group_member_direct(self, group_name): ) await self.connection.connection.execute( """ - INSERT INTO group_member (group_id, member, author) - VALUES (:group_id, :member, :author); + INSERT INTO group_member (group_id, member, audit_log_id) + VALUES (:group_id, :member, :audit_log_id); """, { 'group_id': members_admin_group, 'member': self.author, - 'author': self.author, + 'audit_log_id': await self.audit_log_id(), }, ) diff --git a/test/testbase.py b/test/testbase.py index a08211bef..04c8e1d2d 100644 --- a/test/testbase.py +++ b/test/testbase.py @@ -9,6 +9,7 @@ from functools import wraps from typing import Dict +import databases.core import nest_asyncio from pymysql import IntegrityError from testcontainers.mysql import MySqlContainer @@ -23,6 +24,7 @@ SMConnections, ) from db.python.tables.project import ProjectPermissionsTable +from models.models.project import ProjectId # use this to determine where the db directory is relatively, # as pycharm runs in "test/" folder, and GH runs them in git root @@ -74,8 +76,10 @@ class DbTest(unittest.TestCase): # store connections here, so they can be created PER-CLASS # and don't get recreated per test. dbs: Dict[str, MySqlContainer] = {} - connections: Dict[str, Connection] = {} + connections: Dict[str, databases.Database] = {} author: str + project_id: ProjectId + project_name: str @classmethod def setUpClass(cls) -> None: @@ -127,23 +131,28 @@ async def setup(): ) await sm_db.connect() cls.author = 'testuser' - connection = Connection( + + cls.connections[cls.__name__] = sm_db + formed_connection = Connection( connection=sm_db, - project=1, author=cls.author, + readonly=False, + on_behalf_of=None, + ar_guid=None, + project=None, ) - cls.connections[cls.__name__] = connection - ppt = ProjectPermissionsTable( - connection.connection, allow_full_access=True + connection=formed_connection, + allow_full_access=True, ) cls.project_name = 'test' cls.project_id = await ppt.create_project( project_name=cls.project_name, dataset_name=cls.project_name, - author='testuser', + author=cls.author, check_permissions=False, ) + formed_connection.project = cls.project_id except subprocess.CalledProcessError as e: logging.exception(e) @@ -170,9 +179,17 @@ def tearDownClass(cls) -> None: db.stop() def setUp(self) -> None: - self.project_id = 1 - self.project_name = 'test' - self.connection = self.connections[self.__class__.__name__] + self._connection = self.connections[self.__class__.__name__] + # create a connection on each test so we can generate a new + # audit_log ID for each test + self.connection = Connection( + connection=self._connection, + project=self.project_id, + author=self.author, + readonly=False, + ar_guid=None, + on_behalf_of=None, + ) @run_as_sync async def run_graphql_query(self, query, variables=None): @@ -203,6 +220,10 @@ async def run_graphql_query_async(self, query, variables=None): raise value.errors[0] return value.data + async def audit_log_id(self): + """Get audit_log_id for the test""" + return await self.connection.audit_log_id() + class DbIsolatedTest(DbTest): """