From bb61ae7c3d15a6111e82caeac9ad2fd28cb9ee81 Mon Sep 17 00:00:00 2001 From: Dan Coates Date: Thu, 14 Mar 2024 15:58:14 +1100 Subject: [PATCH 01/29] add initial db migration for new project_groups table --- db/project.xml | 74 +++++++++++ db/python/tables/project.py | 257 ++++++++++++++++-------------------- db/python/utils.py | 20 ++- models/models/group.py | 19 +++ models/models/project.py | 6 +- scripts/sync_seqr.py | 10 +- 6 files changed, 227 insertions(+), 159 deletions(-) create mode 100644 models/models/group.py diff --git a/db/project.xml b/db/project.xml index f9550b803..beb634abb 100644 --- a/db/project.xml +++ b/db/project.xml @@ -1281,4 +1281,78 @@ ALTER TABLE cohort_sequencing_group ADD SYSTEM VERSIONING; ALTER TABLE analysis_cohort ADD SYSTEM VERSIONING; + + + SET @@system_versioning_alter_history = 1; + + + + + + + + + + + + + + + + + -- copy across read and write roles from existing project table + INSERT INTO project_groups ( + SELECT + id AS project_id, + read_group_id AS group_id, + 'read' AS role, + audit_log_id + FROM project + WHERE read_group_id IS NOT NULL + + UNION ALL + + SELECT + id AS project_id, + write_group_id AS group_id, + 'write' AS role, + audit_log_id + FROM project + WHERE write_group_id IS NOT NULL + ) + + + + + + + + + + + + diff --git a/db/python/tables/project.py b/db/python/tables/project.py index 4a8854a3d..ad8e14811 100644 --- a/db/python/tables/project.py +++ b/db/python/tables/project.py @@ -8,12 +8,12 @@ from db.python.connect import Connection, SMConnections from db.python.utils import ( Forbidden, - InternalError, NoProjectAccess, NotFoundError, get_logger, to_db_json, ) +from models.models.group import GroupProjectRole from models.models.project import Project, ProjectId logger = get_logger() @@ -30,13 +30,12 @@ class ProjectPermissionsTable: table_name = 'project' @staticmethod - def get_project_group_name(project_name: str, readonly: bool) -> str: + def get_project_group_name(project_name: str, role: GroupProjectRole) -> str: """ Get group name for a project, for readonly / write """ - if readonly: - return f'{project_name}-read' - return f'{project_name}-write' + + return f'{project_name}-{role.name}' def __init__( self, @@ -102,7 +101,7 @@ async def _get_project_rows_internal(self): Internally cached get_project_rows """ _query = """ - SELECT id, name, meta, dataset, read_group_id, write_group_id + SELECT id, name, meta, dataset FROM project """ rows = await self.connection.fetch_all(_query) @@ -114,9 +113,13 @@ async def _get_project_id_map(self): """ return {p.id: p for p in await self._get_project_rows_internal()} - async def _get_project_name_map(self) -> Dict[str, int]: + async def _get_project_name_map(self) -> Dict[str, int | None]: """Get {project_name: project_id} map""" - return {p.name: p.id for p in await self._get_project_rows_internal()} + return { + p.name: p.id + for p in await self._get_project_rows_internal() + if p.name is not None + } async def _get_project_by_id(self, project_id: ProjectId) -> Project: """Get project by id""" @@ -153,25 +156,38 @@ async def get_all_projects(self, author: str): return await self._get_project_rows_internal() async def get_projects_accessible_by_user( - self, author: str, readonly=True + self, + user: str, + allowed_roles: list[GroupProjectRole], + project_id_filter: list[int] | None, ) -> list[Project]: """ Get projects that are accessible by the specified user """ - assert author - if self.gtable.allow_full_access: - return await self._get_project_rows_internal() - group_name = 
'read_group_id' if readonly else 'write_group_id' - _query = f""" - SELECT p.id + parameters: dict[str, str | list[str] | list[int]] = { + 'user': user, + 'allowed_roles': list(role.name for role in allowed_roles), + } + + project_id_filter_str = '' + if project_id_filter is not None: + parameters['project_ids'] = project_id_filter + project_id_filter_str = 'AND p.id in :project_ids' + + query = f""" + SELECT DISTINCT p.id FROM project p - INNER JOIN group_member gm ON gm.group_id = p.{group_name} - WHERE gm.member = :author + INNER JOIN project_groups pg + ON pg.project_id = p.id + AND pg.role IN :allowed_roles + INNER JOIN group_member gm ON gm.group_id = pg.group_id + WHERE gm.member = :user + {project_id_filter_str} """ - relevant_project_ids = await self.connection.fetch_all( - _query, {'author': author} - ) + + relevant_project_ids = await self.connection.fetch_all(query, parameters) + projects = await self._get_projects_by_ids( [p['id'] for p in relevant_project_ids] ) @@ -179,35 +195,43 @@ async def get_projects_accessible_by_user( return projects async def get_and_check_access_to_project_for_id( - self, user: str, project_id: ProjectId, readonly: bool - ) -> Project: + self, + user: str, + project_id: ProjectId, + allowed_roles: list[GroupProjectRole], + raise_exception: bool = True, + ) -> Project | None: """Get project by id""" project = await self._get_project_by_id(project_id) - has_access = await self.gtable.check_if_member_in_group( - group_id=project.read_group_id if readonly else project.write_group_id, - member=user, + + projects = await self.get_projects_accessible_by_user( + user, allowed_roles, [project_id] ) + has_access = len(projects) == 1 if not has_access: - raise NoProjectAccess([project.name], readonly=readonly, author=user) + if raise_exception: + raise NoProjectAccess( + [project.name], allowed_roles=allowed_roles, author=user + ) + return None return project async def get_and_check_access_to_project_for_name( - self, user: str, project_name: str, readonly: bool - ) -> Project: + self, + user: str, + project_name: str, + allowed_roles: list[GroupProjectRole], + raise_exception: bool = True, + ) -> Project | None: """Get project by name + perform access checks""" project = await self._get_project_by_name(project_name) - has_access = await self.gtable.check_if_member_in_group( - group_id=project.read_group_id if readonly else project.write_group_id, - member=user, + return await self.get_and_check_access_to_project_for_id( + user, project.id, allowed_roles, raise_exception ) - if not has_access: - raise NoProjectAccess([project.name], readonly=readonly, author=user) - - return project async def get_and_check_access_to_projects_for_names( - self, user: str, project_names: list[str], readonly: bool + self, user: str, project_names: list[str], allowed_roles: list[GroupProjectRole] ): """Get projects by names + perform access checks""" project_name_map = await self._get_project_name_map() @@ -216,19 +240,22 @@ async def get_and_check_access_to_projects_for_names( missing_project_names = set(project_names) - set(project_name_map.keys()) if missing_project_names: raise NotFoundError( - f'Could not find projects {", ".join(missing_project_names)}' + f'Could not find projects {', '.join(missing_project_names)}' ) + project_ids = [project_name_map[name] for name in project_names] + # this extra filter is needed for the type sytem to be happy + # that there's no Nones in the list + filtered_project_ids = [p for p in project_ids if p is not None] - projects = await 
self.get_and_check_access_to_projects_for_ids( - user=user, - project_ids=[project_name_map[name] for name in project_names], - readonly=readonly, + return await self.get_and_check_access_to_projects_for_ids( + user, filtered_project_ids, allowed_roles ) - return projects - async def get_and_check_access_to_projects_for_ids( - self, user: str, project_ids: list[ProjectId], readonly: bool + self, + user: str, + project_ids: list[ProjectId], + allowed_roles: list[GroupProjectRole], ) -> list[Project]: """Get project by id""" if not project_ids: @@ -238,83 +265,57 @@ async def get_and_check_access_to_projects_for_ids( ) projects = await self._get_projects_by_ids(project_ids) - # do this all at once to save time - if readonly: - group_id_project_map = {p.read_group_id: p for p in projects} - else: - group_id_project_map = {p.write_group_id: p for p in projects} - group_ids = set(gid for gid in group_id_project_map.keys() if gid) + # Check is any of the provided ids aren't valid project ids + missing_project_ids = set(project_ids) - set(p.id for p in projects) + missing_project_id_strs = [str(p) for p in missing_project_ids] + if missing_project_ids: + raise NotFoundError( + f'Could not find projects with ids {', '.join(missing_project_id_strs)}' + ) - present_group_ids = await self.gtable.check_which_groups_member_has( - group_ids=group_ids, member=user + accessible_projects = await self.get_projects_accessible_by_user( + user, allowed_roles, project_ids ) - missing_group_ids = group_ids - present_group_ids - if missing_group_ids: - # so we can directly return the project names they're missing - missing_project_names = [ - group_id_project_map[gid].name or str(gid) for gid in missing_group_ids - ] - raise NoProjectAccess(missing_project_names, readonly=readonly, author=user) + accessible_project_ids = set(p.id for p in accessible_projects) + missing_project_names = [ + p.name for p in projects if p.id not in accessible_project_ids + ] - return projects + if missing_project_names: + raise NoProjectAccess( + missing_project_names, allowed_roles=allowed_roles, author=user + ) + + return accessible_projects async def check_access_to_project_id( - self, user: str, project_id: ProjectId, readonly: bool, raise_exception=True + self, + user: str, + project_id: ProjectId, + allowed_roles: list[GroupProjectRole], + raise_exception: bool = True, ) -> bool: - """Check whether a user has access to project_id""" - project = await self._get_project_by_id(project_id) - has_access = await self.gtable.check_if_member_in_group( - group_id=project.read_group_id if readonly else project.write_group_id, - member=user, + project = await self.get_and_check_access_to_project_for_id( + user, project_id, allowed_roles, raise_exception ) - if not has_access and raise_exception: - raise NoProjectAccess([project.name], readonly=readonly, author=user) - - return has_access + return project is not None async def check_access_to_project_ids( self, user: str, project_ids: Iterable[ProjectId], - readonly: bool, - raise_exception=True, + allowed_roles: list[GroupProjectRole], ) -> bool: """Check user has access to list of project_ids""" - if not project_ids: - raise Forbidden( - "You don't have access to this resources, as the resource you " - "requested didn't belong to a project" - ) - - projects = await self._get_projects_by_ids(project_ids) - # do this all at once to save time - if readonly: - group_id_project_map = {p.read_group_id: p for p in projects} - else: - group_id_project_map = {p.write_group_id: p for p in 
projects} - - group_ids = set(gid for gid in group_id_project_map.keys() if gid) - present_group_ids = await self.gtable.check_which_groups_member_has( - group_ids=group_ids, member=user + # This will raise an exception if any of the specified project ids are missing + await self.get_and_check_access_to_projects_for_ids( + user, list(project_ids), allowed_roles ) - missing_group_ids = group_ids - present_group_ids - - if missing_group_ids: - # so we can directly return the project names they're missing - missing_project_names = [ - group_id_project_map[gid].name or str(gid) for gid in missing_group_ids - ] - if raise_exception: - raise NoProjectAccess( - missing_project_names, readonly=readonly, author=user - ) - return False - return True - async def check_project_creator_permissions(self, author): + async def check_project_creator_permissions(self, author: str): """Check author has project_creator permissions""" # check permissions in here is_in_group = await self.gtable.check_if_member_in_group_name( @@ -328,21 +329,6 @@ async def check_project_creator_permissions(self, author): # endregion AUTH - async def get_project_ids_from_names_and_user( - self, user: str, project_names: List[str], readonly: bool - ) -> List[ProjectId]: - """Get project ids from project names and the user""" - if not user: - raise InternalError('An internal error occurred during authorization') - - project_name_map = await self._get_project_name_map() - ordered_project_ids = [project_name_map[name] for name in project_names] - await self.check_access_to_project_ids( - user, ordered_project_ids, readonly=readonly, raise_exception=True - ) - - return ordered_project_ids - # region CREATE / UPDATE async def create_project( @@ -350,23 +336,13 @@ async def create_project( project_name: str, dataset_name: str, author: str, - check_permissions=True, + check_permissions: bool = True, ): """Create project row""" if check_permissions: await self.check_project_creator_permissions(author) async with self.connection.transaction(): - audit_log_id = await self.audit_log_id() - read_group_id = await self.gtable.create_group( - self.get_project_group_name(project_name, readonly=True), - audit_log_id=audit_log_id, - ) - write_group_id = await self.gtable.create_group( - self.get_project_group_name(project_name, readonly=False), - audit_log_id=audit_log_id, - ) - _query = """\ INSERT INTO project (name, dataset, audit_log_id, read_group_id, write_group_id) VALUES (:name, :dataset, :audit_log_id, :read_group_id, :write_group_id) @@ -375,8 +351,6 @@ async def create_project( 'name': project_name, 'dataset': dataset_name, 'audit_log_id': await self.audit_log_id(), - 'read_group_id': read_group_id, - 'write_group_id': write_group_id, } project_id = await self.connection.fetch_val(_query, values) @@ -465,18 +439,21 @@ async def delete_project_data( """ values: dict = {'project': project_id} if delete_project: - group_ids = await self.connection.fetch_one( + group_ids_rows = await self.connection.fetch_all( """ - SELECT read_group_id, write_group_id - FROM project WHERE id = :project' + SELECT group_id + FROM project_groups WHERE project_id = :project' """ ) + group_ids = set(r['group_id'] for r in group_ids_rows) + _query += 'DELETE FROM project WHERE id = :project;\n' - _query += 'DELETE FROM `group` WHERE id IN :group_ids\n' - values['group_ids'] = [ - group_ids['read_group_id'], - group_ids['write_group_id'], - ] + if len(group_ids) > 0: + _query += 'DELETE FROM `group` WHERE id IN :group_ids\n' + _query += ( + 'DELETE FROM 
`project_groups` WHERE project_id = :project\n' + ) + values['group_ids'] = list(group_ids) await self.connection.execute(_query, {'project': project_id}) @@ -519,12 +496,12 @@ async def get_seqr_projects(self) -> list[dict[str, Any]]: class GroupTable: """ - Capture Analysis table operations and queries + Capture Group table operations and queries """ table_name = 'group' - def __init__(self, connection: Database, allow_full_access: bool = None): + def __init__(self, connection: Database, allow_full_access: bool | None = None): if not isinstance(connection, Database): raise ValueError( f'Invalid type connection, expected Database, got {type(connection)}, ' diff --git a/db/python/utils.py b/db/python/utils.py index 1ea11e3d8..a1f5d47d3 100644 --- a/db/python/utils.py +++ b/db/python/utils.py @@ -4,9 +4,10 @@ import os import re from enum import Enum -from typing import Any, Generic, Sequence, TypeVar +from typing import Any, Generic, TypeVar from models.base import SMBase +from models.models.group import GroupProjectRole T = TypeVar('T') @@ -64,21 +65,19 @@ class NoProjectAccess(Forbidden): def __init__( self, - project_names: Sequence[str] | None, + project_names: list[str], author: str, - *args, - readonly: bool | None = None, + allowed_roles: list[GroupProjectRole], + *args: tuple[Any, ...], ): project_names_str = ( ', '.join(repr(p) for p in project_names) if project_names else '' ) - access_type = '' - if readonly is False: - access_type = 'write ' + allowed_roles_str = ' or '.join([r.name for r in allowed_roles]) super().__init__( - f'{author} does not have {access_type}access to resources from the ' - f'following project(s), or they may not exist: {project_names_str}', + f'{author} does not have {allowed_roles_str} access to resources from the ' + f'follgroowing project(s), or they may not exist: {project_names_str}', *args, ) @@ -431,9 +430,8 @@ def get_logger(): return _logger -def to_db_json(val): +def to_db_json(val: Any): """Convert val to json for DB""" - # return psycopg2.extras.Json(val) return json.dumps(val) diff --git a/models/models/group.py b/models/models/group.py new file mode 100644 index 000000000..9d2bb3db7 --- /dev/null +++ b/models/models/group.py @@ -0,0 +1,19 @@ +from enum import Enum + +from models.base import SMBase + +GroupProjectRole = Enum('GroupProjectRole', ['read', 'contribute', 'write']) + + +class Group(SMBase): + """Row for project in 'project' table""" + + id: int + name: str + role: GroupProjectRole + + @staticmethod + def from_db(kwargs): + """From DB row, with db keys""" + kwargs = dict(kwargs) + return Group(**kwargs) diff --git a/models/models/project.py b/models/models/project.py index 9ca19542f..430c66253 100644 --- a/models/models/project.py +++ b/models/models/project.py @@ -9,12 +9,10 @@ class Project(SMBase): """Row for project in 'project' table""" - id: Optional[ProjectId] = None - name: Optional[str] = None + id: ProjectId + name: str dataset: Optional[str] = None meta: Optional[dict] = None - read_group_id: Optional[int] = None - write_group_id: Optional[int] = None @staticmethod def from_db(kwargs): diff --git a/scripts/sync_seqr.py b/scripts/sync_seqr.py index 45b415da1..fa685666f 100644 --- a/scripts/sync_seqr.py +++ b/scripts/sync_seqr.py @@ -353,9 +353,11 @@ def _parse_consanguity(consanguity): def process_row(row): return { - seqr_key: key_processor[sm_key](row[sm_key]) - if sm_key in key_processor - else row[sm_key] + seqr_key: ( + key_processor[sm_key](row[sm_key]) + if sm_key in key_processor + else row[sm_key] + ) for 
seqr_key, sm_key in seqr_map.items() if sm_key in row } @@ -558,7 +560,7 @@ async def _make_update_igv_call(update): if not updates: continue print( - f'{dataset} :: Updating CRAMs {idx * chunk_size + 1} -> {(min((idx + 1 ) * chunk_size, len(all_updates)))} (/{len(all_updates)})' + f'{dataset} :: Updating CRAMs {idx * chunk_size + 1} -> {(min((idx + 1) * chunk_size, len(all_updates)))} (/{len(all_updates)})' ) responses = await asyncio.gather( From e257500c8f796978373088c9637ddc9103558ade Mon Sep 17 00:00:00 2001 From: Dan Coates Date: Fri, 22 Mar 2024 11:56:14 +1100 Subject: [PATCH 02/29] Update project based db connections --- api/routes/analysis.py | 12 ++--- api/routes/assay.py | 4 +- api/routes/family.py | 6 +-- api/routes/participant.py | 10 ++-- api/routes/sample.py | 8 +-- api/routes/sequencing_groups.py | 4 +- api/routes/web.py | 4 +- api/utils/db.py | 95 +++++++++++++++++++++++++++++++-- db/python/connect.py | 20 ++++--- db/python/tables/project.py | 24 +++++---- db/python/utils.py | 2 +- 11 files changed, 142 insertions(+), 47 deletions(-) diff --git a/api/routes/analysis.py b/api/routes/analysis.py index f6306c449..2c6010163 100644 --- a/api/routes/analysis.py +++ b/api/routes/analysis.py @@ -11,7 +11,7 @@ from api.utils.dates import parse_date_only_string from api.utils.db import ( Connection, - get_project_readonly_connection, + get_project_read_connection, get_project_write_connection, get_projectless_db_connection, ) @@ -135,7 +135,7 @@ async def update_analysis( ) async def get_all_sample_ids_without_analysis_type( analysis_type: str, - connection: Connection = get_project_readonly_connection, + connection: Connection = get_project_read_connection, ): """get_all_sample_ids_without_analysis_type""" atable = AnalysisLayer(connection) @@ -155,7 +155,7 @@ async def get_all_sample_ids_without_analysis_type( operation_id='getIncompleteAnalyses', ) async def get_incomplete_analyses( - connection: Connection = get_project_readonly_connection, + connection: Connection = get_project_read_connection, ): """Get analyses with status queued or in-progress""" atable = AnalysisLayer(connection) @@ -170,7 +170,7 @@ async def get_incomplete_analyses( ) async def get_latest_complete_analysis_for_type( analysis_type: str, - connection: Connection = get_project_readonly_connection, + connection: Connection = get_project_read_connection, ): """Get (SINGLE) latest complete analysis for some analysis type""" alayer = AnalysisLayer(connection) @@ -188,7 +188,7 @@ async def get_latest_complete_analysis_for_type( async def get_latest_complete_analysis_for_type_post( analysis_type: str, meta: dict[str, Any] = Body(..., embed=True), # type: ignore[assignment] - connection: Connection = get_project_readonly_connection, + connection: Connection = get_project_read_connection, ): """ Get SINGLE latest complete analysis for some analysis type @@ -274,7 +274,7 @@ async def get_analysis_runner_log( async def get_sample_reads_map( export_type: ExportType = ExportType.JSON, sequencing_types: list[str] = Query(None), # type: ignore - connection: Connection = get_project_readonly_connection, + connection: Connection = get_project_read_connection, ): """ Get map of ExternalSampleId pathToCram InternalSeqGroupID for seqr diff --git a/api/routes/assay.py b/api/routes/assay.py index 0bc4cb48e..46ee1c9c9 100644 --- a/api/routes/assay.py +++ b/api/routes/assay.py @@ -2,7 +2,7 @@ from fastapi import APIRouter -from api.utils import get_project_readonly_connection +from api.utils import get_project_read_connection from 
api.utils.db import Connection, get_projectless_db_connection from db.python.layers.assay import AssayLayer from db.python.tables.assay import AssayFilter @@ -53,7 +53,7 @@ async def get_assay_by_id( '/{project}/external_id/{external_id}/details', operation_id='getAssayByExternalId' ) async def get_assay_by_external_id( - external_id: str, connection=get_project_readonly_connection + external_id: str, connection=get_project_read_connection ): """Get an assay by ONE of its external identifiers""" assay_layer = AssayLayer(connection) diff --git a/api/routes/family.py b/api/routes/family.py index c7e8a90d5..49a39c83e 100644 --- a/api/routes/family.py +++ b/api/routes/family.py @@ -12,7 +12,7 @@ from api.utils import get_projectless_db_connection from api.utils.db import ( Connection, - get_project_readonly_connection, + get_project_read_connection, get_project_write_connection, ) from api.utils.export import ExportType @@ -75,7 +75,7 @@ async def get_pedigree( replace_with_family_external_ids: bool = True, include_header: bool = True, empty_participant_value: Optional[str] = None, - connection: Connection = get_project_readonly_connection, + connection: Connection = get_project_read_connection, include_participants_not_in_families: bool = False, ): """ @@ -142,7 +142,7 @@ async def get_pedigree( async def get_families( participant_ids: Optional[List[int]] = Query(None), sample_ids: Optional[List[str]] = Query(None), - connection: Connection = get_project_readonly_connection, + connection: Connection = get_project_read_connection, ) -> List[Family]: """Get families for some project""" family_layer = FamilyLayer(connection) diff --git a/api/routes/participant.py b/api/routes/participant.py index d2b5e23b5..8fd3dde23 100644 --- a/api/routes/participant.py +++ b/api/routes/participant.py @@ -9,7 +9,7 @@ from api.utils import get_projectless_db_connection from api.utils.db import ( Connection, - get_project_readonly_connection, + get_project_read_connection, get_project_write_connection, ) from api.utils.export import ExportType @@ -47,7 +47,7 @@ async def get_individual_metadata_template_for_seqr( external_participant_ids: list[str] | None = Query(default=None), # type: ignore[assignment] # pylint: disable=invalid-name replace_with_participant_external_ids: bool = True, - connection: Connection = get_project_readonly_connection, + connection: Connection = get_project_read_connection, ): """Get individual metadata template for SEQR as a CSV""" participant_layer = ParticipantLayer(connection) @@ -90,7 +90,7 @@ async def get_individual_metadata_template_for_seqr( async def get_id_map_by_external_ids( external_participant_ids: list[str], allow_missing: bool = False, - connection: Connection = get_project_readonly_connection, + connection: Connection = get_project_read_connection, ): """Get ID map of participants, by external_id""" player = ParticipantLayer(connection) @@ -121,7 +121,7 @@ async def get_external_participant_id_to_sequencing_group_id( sequencing_type: str = None, export_type: ExportType = ExportType.JSON, flip_columns: bool = False, - connection: Connection = get_project_readonly_connection, + connection: Connection = get_project_read_connection, ): """ Get csv / tsv export of external_participant_id to sequencing_group_id @@ -214,7 +214,7 @@ class QueryParticipantCriteria(SMBase): ) async def get_participants( criteria: QueryParticipantCriteria, - connection: Connection = get_project_readonly_connection, + connection: Connection = get_project_read_connection, ): """Get participants, 
default ALL participants in project""" player = ParticipantLayer(connection) diff --git a/api/routes/sample.py b/api/routes/sample.py index 982e89798..f95277a8b 100644 --- a/api/routes/sample.py +++ b/api/routes/sample.py @@ -2,7 +2,7 @@ from api.utils.db import ( Connection, - get_project_readonly_connection, + get_project_read_connection, get_project_write_connection, get_projectless_db_connection, ) @@ -62,7 +62,7 @@ async def upsert_samples( async def get_sample_id_map_by_external( external_ids: list[str], allow_missing: bool = False, - connection: Connection = get_project_readonly_connection, + connection: Connection = get_project_read_connection, ): """Get map of sample IDs, { [externalId]: internal_sample_id }""" st = SampleLayer(connection) @@ -91,7 +91,7 @@ async def get_sample_id_map_by_internal( '/{project}/id-map/internal/all', operation_id='getAllSampleIdMapByInternal' ) async def get_all_sample_id_map_by_internal( - connection: Connection = get_project_readonly_connection, + connection: Connection = get_project_read_connection, ): """Get map of ALL sample IDs, { [internal_id]: external_sample_id }""" st = SampleLayer(connection) @@ -107,7 +107,7 @@ async def get_all_sample_id_map_by_internal( operation_id='getSampleByExternalId', ) async def get_sample_by_external_id( - external_id: str, connection: Connection = get_project_readonly_connection + external_id: str, connection: Connection = get_project_read_connection ): """Get sample by external ID""" st = SampleLayer(connection) diff --git a/api/routes/sequencing_groups.py b/api/routes/sequencing_groups.py index 9a35ee1e4..d9842c5b1 100644 --- a/api/routes/sequencing_groups.py +++ b/api/routes/sequencing_groups.py @@ -5,7 +5,7 @@ from api.utils.db import ( Connection, - get_project_readonly_connection, + get_project_read_connection, get_project_write_connection, get_projectless_db_connection, ) @@ -42,7 +42,7 @@ async def get_sequencing_group( @router.get('/project/{project}', operation_id='getAllSequencingGroupIdsBySampleByType') async def get_all_sequencing_group_ids_by_sample_by_type( - connection: Connection = get_project_readonly_connection, + connection: Connection = get_project_read_connection, ) -> dict[str, dict[str, list[str]]]: """Creates a new sample, and returns the internal sample ID""" st = SequencingGroupLayer(connection) diff --git a/api/routes/web.py b/api/routes/web.py index 71ab77292..5edfb5010 100644 --- a/api/routes/web.py +++ b/api/routes/web.py @@ -10,7 +10,7 @@ from api.utils.db import ( Connection, - get_project_readonly_connection, + get_project_read_connection, get_project_write_connection, get_projectless_db_connection, ) @@ -42,7 +42,7 @@ async def get_project_summary( grid_filter: list[SearchItem], limit: int = 20, token: Optional[int] = 0, - connection: Connection = get_project_readonly_connection, + connection: Connection = get_project_read_connection, ) -> ProjectSummary: """Creates a new sample, and returns the internal sample ID""" st = WebLayer(connection) diff --git a/api/utils/db.py b/api/utils/db.py index 608a49888..10a1ca0e4 100644 --- a/api/utils/db.py +++ b/api/utils/db.py @@ -12,6 +12,7 @@ from db.python.connect import Connection, SMConnections from db.python.gcp_connect import BqConnection, PubSubConnection from db.python.tables.project import ProjectPermissionsTable +from models.models.group import GroupProjectRole EXPECTED_AUDIENCE = getenv('SM_OAUTHAUDIENCE') @@ -96,7 +97,7 @@ async def dependable_get_write_project_connection( return await 
ProjectPermissionsTable.get_project_connection( project_name=project, author=author, - readonly=False, + allowed_roles={GroupProjectRole.write}, ar_guid=ar_guid, on_behalf_of=on_behalf_of, meta=meta, @@ -132,21 +133,102 @@ async def HACK_dependable_contributor_project_connection( return connection -async def dependable_get_readonly_project_connection( +async def HACK_dependable_contributor_project_connection( project: str, + request: Request, author: str = Depends(authenticate), ar_guid: str = Depends(get_ar_guid), extra_values: dict | None = Depends(get_extra_audit_log_values), ) -> Connection: """FastAPI handler for getting connection WITH project""" - meta = {} + meta = {"path": request.url.path} + if request.client: + meta["ip"] = request.client.host + if extra_values: meta.update(extra_values) + meta['role'] = 'contributor' + + # hack by making it appear readonly + connection = await ProjectPermissionsTable.get_project_connection( + project_name=project, + author=author, + allowed_roles={ + GroupProjectRole.read, + GroupProjectRole.write, + GroupProjectRole.contribute, + }, + on_behalf_of=None, + ar_guid=ar_guid, + meta=meta, + ) + + +async def dependable_get_contribute_project_connection( + project: str, + request: Request, + author: str = Depends(authenticate), + ar_guid: str = Depends(get_ar_guid), + extra_values: dict | None = Depends(get_extra_audit_log_values), +) -> Connection: + """FastAPI handler for getting connection WITH project""" return await ProjectPermissionsTable.get_project_connection( project_name=project, author=author, - readonly=True, + allowed_roles={GroupProjectRole.write, GroupProjectRole.contribute}, + on_behalf_of=None, + ar_guid=ar_guid, + meta=meta, + ) + + # then hack it so + connection.readonly = False + + return connection + + +async def dependable_get_read_project_connection( + project: str, + request: Request, + author: str = Depends(authenticate), + ar_guid: str = Depends(get_ar_guid), + extra_values: dict | None = Depends(get_extra_audit_log_values), +) -> Connection: + """FastAPI handler for getting connection WITH project""" + meta = {"path": request.url.path} + if request.client: + meta["ip"] = request.client.host + + if extra_values: + meta.update(extra_values) + + return await ProjectPermissionsTable.get_project_connection( + project_name=project, + author=author, + allowed_roles={ + GroupProjectRole.read, + GroupProjectRole.write, + GroupProjectRole.contribute, + }, + on_behalf_of=None, + ar_guid=ar_guid, + meta=meta, + ) + + +async def dependable_get_contribute_project_connection( + project: str, + request: Request, + author: str = Depends(authenticate), + ar_guid: str = Depends(get_ar_guid), + extra_values: dict | None = Depends(get_extra_audit_log_values), +) -> Connection: + """FastAPI handler for getting connection WITH project""" + return await ProjectPermissionsTable.get_project_connection( + project_name=project, + author=author, + allowed_roles={GroupProjectRole.write, GroupProjectRole.contribute}, on_behalf_of=None, ar_guid=ar_guid, meta=meta, @@ -207,7 +289,10 @@ def validate_iap_jwt_and_get_email(iap_jwt, audience): get_author = Depends(authenticate) -get_project_readonly_connection = Depends(dependable_get_readonly_project_connection) +get_project_read_connection = Depends(dependable_get_read_project_connection) +get_project_contribute_connection = Depends( + dependable_get_contribute_project_connection +) HACK_get_project_contributor_connection = Depends( HACK_dependable_contributor_project_connection ) diff --git 
a/db/python/connect.py b/db/python/connect.py index 7bedf571e..14c40da3e 100644 --- a/db/python/connect.py +++ b/db/python/connect.py @@ -3,6 +3,7 @@ """ Code for connecting to Postgres database """ + import abc import asyncio import json @@ -13,6 +14,7 @@ from api.settings import LOG_DATABASE_QUERIES from db.python.utils import InternalError +from models.models.group import GroupProjectRole logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) @@ -51,7 +53,7 @@ def __init__( project: int | None, author: str, on_behalf_of: str | None, - readonly: bool, + allowed_roles: set[GroupProjectRole], ar_guid: str | None, meta: dict[str, str] | None = None, ): @@ -59,7 +61,7 @@ def __init__( self.project: int | None = project self.author: str = author self.on_behalf_of: str | None = on_behalf_of - self.readonly: bool = readonly + self.allowed_roles: set[GroupProjectRole] = allowed_roles self.ar_guid: str | None = ar_guid self.meta = meta @@ -68,7 +70,11 @@ def __init__( async def audit_log_id(self): """Get audit_log ID for write operations, cached per connection""" - if self.readonly: + # If connection doesn't have a writeable role, don't allow getting an audit log id + if not self.allowed_roles & { + GroupProjectRole.write, + GroupProjectRole.contribute, + }: raise InternalError( 'Trying to get a audit_log ID, but not a write connection' ) @@ -114,7 +120,7 @@ def get_connection_string(self): class ConnectionStringDatabaseConfiguration(DatabaseConfiguration): """Database Configuration that takes a literal DatabaseConfiguration""" - def __init__(self, connection_string): + def __init__(self, connection_string: str): self.connection_string = connection_string def get_connection_string(self): @@ -240,8 +246,8 @@ async def get_connection_no_project( conn = await SMConnections.get_made_connection() - # we don't authenticate project-less connection, but rely on the - # the endpoint to validate the resources + # all roles are allowed here as we don't authenticate project-less connection, + # but rely on the the endpoint to validate the resources return Connection( connection=conn, @@ -249,6 +255,6 @@ async def get_connection_no_project( project=None, on_behalf_of=None, ar_guid=ar_guid, - readonly=False, + allowed_roles={r for r in GroupProjectRole}, meta=meta, ) diff --git a/db/python/tables/project.py b/db/python/tables/project.py index ad8e14811..9a92e1985 100644 --- a/db/python/tables/project.py +++ b/db/python/tables/project.py @@ -57,7 +57,7 @@ async def get_project_connection( *, author: str, project_name: str, - readonly: bool, + allowed_roles: set[GroupProjectRole], ar_guid: str, on_behalf_of: str | None = None, meta: dict[str, str] | None = None, @@ -70,14 +70,18 @@ async def get_project_connection( pt = ProjectPermissionsTable(connection=None, database_connection=conn) project = await pt.get_and_check_access_to_project_for_name( - user=author, project_name=project_name, readonly=readonly + user=author, project_name=project_name, allowed_roles=allowed_roles ) + # python types doesn't know this can't be none due + # to the default raise_exception of true + assert project + return Connection( connection=conn, author=author, project=project.id, - readonly=readonly, + allowed_roles=allowed_roles, on_behalf_of=on_behalf_of, ar_guid=ar_guid, meta=meta, @@ -158,7 +162,7 @@ async def get_all_projects(self, author: str): async def get_projects_accessible_by_user( self, user: str, - allowed_roles: list[GroupProjectRole], + allowed_roles: set[GroupProjectRole], project_id_filter: 
list[int] | None, ) -> list[Project]: """ @@ -198,7 +202,7 @@ async def get_and_check_access_to_project_for_id( self, user: str, project_id: ProjectId, - allowed_roles: list[GroupProjectRole], + allowed_roles: set[GroupProjectRole], raise_exception: bool = True, ) -> Project | None: """Get project by id""" @@ -221,7 +225,7 @@ async def get_and_check_access_to_project_for_name( self, user: str, project_name: str, - allowed_roles: list[GroupProjectRole], + allowed_roles: set[GroupProjectRole], raise_exception: bool = True, ) -> Project | None: """Get project by name + perform access checks""" @@ -231,7 +235,7 @@ async def get_and_check_access_to_project_for_name( ) async def get_and_check_access_to_projects_for_names( - self, user: str, project_names: list[str], allowed_roles: list[GroupProjectRole] + self, user: str, project_names: list[str], allowed_roles: set[GroupProjectRole] ): """Get projects by names + perform access checks""" project_name_map = await self._get_project_name_map() @@ -255,7 +259,7 @@ async def get_and_check_access_to_projects_for_ids( self, user: str, project_ids: list[ProjectId], - allowed_roles: list[GroupProjectRole], + allowed_roles: set[GroupProjectRole], ) -> list[Project]: """Get project by id""" if not project_ids: @@ -294,7 +298,7 @@ async def check_access_to_project_id( self, user: str, project_id: ProjectId, - allowed_roles: list[GroupProjectRole], + allowed_roles: set[GroupProjectRole], raise_exception: bool = True, ) -> bool: project = await self.get_and_check_access_to_project_for_id( @@ -306,7 +310,7 @@ async def check_access_to_project_ids( self, user: str, project_ids: Iterable[ProjectId], - allowed_roles: list[GroupProjectRole], + allowed_roles: set[GroupProjectRole], ) -> bool: """Check user has access to list of project_ids""" # This will raise an exception if any of the specified project ids are missing diff --git a/db/python/utils.py b/db/python/utils.py index a1f5d47d3..14f655108 100644 --- a/db/python/utils.py +++ b/db/python/utils.py @@ -67,7 +67,7 @@ def __init__( self, project_names: list[str], author: str, - allowed_roles: list[GroupProjectRole], + allowed_roles: set[GroupProjectRole], *args: tuple[Any, ...], ): project_names_str = ( From ea7c3fd54adc1822ed19b21c56189f00da275bb3 Mon Sep 17 00:00:00 2001 From: Dan Coates Date: Fri, 22 Mar 2024 15:49:26 +1100 Subject: [PATCH 03/29] update usage of project and group permission functions --- api/graphql/loaders.py | 6 ++++-- api/graphql/schema.py | 17 +++++++++++----- api/routes/analysis.py | 15 +++++++++----- api/routes/assay.py | 13 +++++++++---- api/routes/family.py | 2 +- api/routes/participant.py | 2 +- api/routes/project.py | 10 ++++++---- api/routes/sample.py | 11 +++++++++-- api/routes/web.py | 3 ++- api/server.py | 2 +- api/utils/__init__.py | 9 +-------- api/utils/db.py | 7 +++++-- db/python/connect.py | 2 +- db/python/layers/analysis.py | 17 ++++++++++------ db/python/layers/assay.py | 17 ++++++++-------- db/python/layers/audit_log.py | 5 +++-- db/python/layers/family.py | 9 +++++---- db/python/layers/participant.py | 15 +++++++++----- db/python/layers/sample.py | 29 ++++++++++++++++------------ db/python/layers/seqr.py | 11 +++++++---- db/python/layers/sequencing_group.py | 9 +++++---- db/python/layers/web.py | 3 ++- db/python/tables/project.py | 5 +++-- models/models/group.py | 10 ++++++++++ test/test_project_groups.py | 26 +++++++++++++++---------- test/testbase.py | 5 +++-- 26 files changed, 163 insertions(+), 97 deletions(-) diff --git a/api/graphql/loaders.py 
b/api/graphql/loaders.py index 2b4980a27..87889ae84 100644 --- a/api/graphql/loaders.py +++ b/api/graphql/loaders.py @@ -9,7 +9,8 @@ from fastapi import Request from strawberry.dataloader import DataLoader -from api.utils import get_projectless_db_connection, group_by +from api.utils import group_by +from api.utils.db import get_projectless_db_connection from db.python.layers import ( AnalysisLayer, AssayLayer, @@ -38,6 +39,7 @@ ) from models.models.audit_log import AuditLogInternal from models.models.family import PedRowInternal +from models.models.group import ReadAccessRoles class LoaderKeys(enum.Enum): @@ -346,7 +348,7 @@ async def load_projects_for_ids(project_ids: list[int], connection) -> list[Proj """ pttable = ProjectPermissionsTable(connection) projects = await pttable.get_and_check_access_to_projects_for_ids( - user=connection.author, project_ids=project_ids, readonly=True + user=connection.user, project_ids=project_ids, allowed_roles=ReadAccessRoles ) p_by_id = {p.id: p for p in projects} diff --git a/api/graphql/schema.py b/api/graphql/schema.py index 7352c8cd5..259481bfa 100644 --- a/api/graphql/schema.py +++ b/api/graphql/schema.py @@ -53,6 +53,7 @@ ) from models.models.analysis_runner import AnalysisRunnerInternal from models.models.family import PedRowInternal +from models.models.group import ReadAccessRoles from models.models.project import ProjectId from models.models.sample import sample_id_transform_to_raw from models.utils.cohort_id_format import cohort_id_format, cohort_id_transform_to_raw @@ -824,7 +825,9 @@ async def analyses( ptable = ProjectPermissionsTable(connection) project_ids = project.all_values() projects = await ptable.get_and_check_access_to_projects_for_names( - user=connection.author, project_names=project_ids, readonly=True + user=connection.author, + project_names=project_ids, + allowed_roles=ReadAccessRoles, ) project_id_map: dict[str, int] = { p.name: p.id for p in projects if p.name and p.id @@ -1053,7 +1056,7 @@ async def project(self, info: Info, name: str) -> GraphQLProject: connection = info.context['connection'] ptable = ProjectPermissionsTable(connection) project = await ptable.get_and_check_access_to_project_for_name( - user=connection.author, project_name=name, readonly=True + user=connection.author, project_name=name, allowed_roles=ReadAccessRoles ) return GraphQLProject.from_internal(project) @@ -1083,7 +1086,9 @@ async def sample( if project: project_names = project.all_values() projects = await ptable.get_and_check_access_to_projects_for_names( - user=connection.author, project_names=project_names, readonly=True + user=connection.author, + project_names=project_names, + allowed_roles=ReadAccessRoles, ) project_name_map = {p.name: p.id for p in projects if p.name and p.id} @@ -1135,7 +1140,9 @@ async def sequencing_groups( if project: project_names = project.all_values() projects = await ptable.get_and_check_access_to_projects_for_names( - user=connection.author, project_names=project_names, readonly=True + user=connection.author, + project_names=project_names, + allowed_roles=ReadAccessRoles, ) project_id_map = {p.name: p.id for p in projects if p.name and p.id} _project_filter = project.to_internal_filter_mapped( @@ -1193,7 +1200,7 @@ async def my_projects(self, info: Info) -> list[GraphQLProject]: connection = info.context['connection'] ptable = ProjectPermissionsTable(connection) projects = await ptable.get_projects_accessible_by_user( - connection.author, readonly=True + connection.author, allowed_roles=ReadAccessRoles ) return 
[GraphQLProject.from_internal(p) for p in projects] diff --git a/api/routes/analysis.py b/api/routes/analysis.py index 2c6010163..75e3af4a3 100644 --- a/api/routes/analysis.py +++ b/api/routes/analysis.py @@ -26,6 +26,7 @@ AnalysisInternal, ProportionalDateTemporalMethod, ) +from models.models.group import ReadAccessRoles from models.utils.sequencing_group_id_format import ( sequencing_group_id_format, sequencing_group_id_format_list, @@ -230,7 +231,9 @@ async def query_analyses( pt = ProjectPermissionsTable(connection) projects = await pt.get_and_check_access_to_projects_for_names( - user=connection.author, project_names=query.projects, readonly=True + user=connection.author, + project_names=query.projects, + allowed_roles=ReadAccessRoles, ) project_name_map = {p.name: p.id for p in projects} atable = AnalysisLayer(connection) @@ -253,9 +256,10 @@ async def get_analysis_runner_log( project_ids = None if project_names: pt = ProjectPermissionsTable(connection) - project_ids = await pt.get_project_ids_from_names_and_user( - connection.author, project_names, readonly=True + projects = await pt.get_and_check_access_to_projects_for_names( + connection.author, project_names, allowed_roles=ReadAccessRoles ) + project_ids = [p.id for p in projects] results = await atable.get_analysis_runner_log( project_ids=project_ids, @@ -342,9 +346,10 @@ async def get_proportionate_map( } """ pt = ProjectPermissionsTable(connection) - project_ids = await pt.get_project_ids_from_names_and_user( - connection.author, projects, readonly=True + project_list = await pt.get_and_check_access_to_projects_for_names( + connection.author, projects, allowed_roles=ReadAccessRoles ) + project_ids = [p.id for p in project_list] start_date = parse_date_only_string(start) if start else None end_date = parse_date_only_string(end) if end else None diff --git a/api/routes/assay.py b/api/routes/assay.py index 46ee1c9c9..dcf584f54 100644 --- a/api/routes/assay.py +++ b/api/routes/assay.py @@ -2,14 +2,18 @@ from fastapi import APIRouter -from api.utils import get_project_read_connection -from api.utils.db import Connection, get_projectless_db_connection +from api.utils.db import ( + Connection, + get_project_read_connection, + get_projectless_db_connection, +) from db.python.layers.assay import AssayLayer from db.python.tables.assay import AssayFilter from db.python.tables.project import ProjectPermissionsTable from db.python.utils import GenericFilter from models.base import SMBase from models.models.assay import AssayUpsert +from models.models.group import ReadAccessRoles from models.utils.sample_id_format import sample_id_transform_to_raw_list router = APIRouter(prefix='/assay', tags=['assay']) @@ -84,9 +88,10 @@ async def get_assays_by_criteria( pids: list[int] | None = None if criteria.projects: - pids = await pt.get_project_ids_from_names_and_user( - connection.author, criteria.projects, readonly=True + project_list = await pt.get_and_check_access_to_projects_for_names( + connection.author, criteria.projects, allowed_roles=ReadAccessRoles ) + pids = [p.id for p in project_list] unwrapped_sample_ids: list[int] | None = None if criteria.sample_ids: diff --git a/api/routes/family.py b/api/routes/family.py index 49a39c83e..aa7827090 100644 --- a/api/routes/family.py +++ b/api/routes/family.py @@ -9,11 +9,11 @@ from pydantic import BaseModel from starlette.responses import StreamingResponse -from api.utils import get_projectless_db_connection from api.utils.db import ( Connection, get_project_read_connection, 
get_project_write_connection, + get_projectless_db_connection, ) from api.utils.export import ExportType from api.utils.extensions import guess_delimiter_by_upload_file_obj diff --git a/api/routes/participant.py b/api/routes/participant.py index 8fd3dde23..4a9409f34 100644 --- a/api/routes/participant.py +++ b/api/routes/participant.py @@ -6,11 +6,11 @@ from fastapi.params import Query from starlette.responses import JSONResponse, StreamingResponse -from api.utils import get_projectless_db_connection from api.utils.db import ( Connection, get_project_read_connection, get_project_write_connection, + get_projectless_db_connection, ) from api.utils.export import ExportType from db.python.layers.participant import ParticipantLayer diff --git a/api/routes/project.py b/api/routes/project.py index 559938485..e375066a3 100644 --- a/api/routes/project.py +++ b/api/routes/project.py @@ -4,6 +4,7 @@ from api.utils.db import Connection, get_projectless_db_connection from db.python.tables.project import ProjectPermissionsTable +from models.models.group import FullWriteAccessRoles, GroupProjectRole, ReadAccessRoles from models.models.project import Project router = APIRouter(prefix='/project', tags=['project']) @@ -21,7 +22,7 @@ async def get_my_projects(connection=get_projectless_db_connection): """Get projects I have access to""" ptable = ProjectPermissionsTable(connection) projects = await ptable.get_projects_accessible_by_user( - author=connection.author, readonly=True + user=connection.author, allowed_roles=ReadAccessRoles ) return [p.name for p in projects] @@ -86,7 +87,7 @@ async def delete_project_data( """ ptable = ProjectPermissionsTable(connection) p_obj = await ptable.get_and_check_access_to_project_for_name( - user=connection.author, project_name=project, readonly=False + user=connection.author, project_name=project, allowed_roles=FullWriteAccessRoles ) success = await ptable.delete_project_data( project_id=p_obj.id, delete_project=delete_project, author=connection.author @@ -99,7 +100,8 @@ async def delete_project_data( async def update_project_members( project: str, members: list[str], - readonly: bool, + # @TODO change this to accept a role + role: str, connection: Connection = get_projectless_db_connection, ): """ @@ -108,7 +110,7 @@ async def update_project_members( """ ptable = ProjectPermissionsTable(connection) await ptable.set_group_members( - group_name=ptable.get_project_group_name(project, readonly=readonly), + group_name=ptable.get_project_group_name(project, role=GroupProjectRole(role)), members=members, author=connection.author, ) diff --git a/api/routes/sample.py b/api/routes/sample.py index f95277a8b..8919c5d52 100644 --- a/api/routes/sample.py +++ b/api/routes/sample.py @@ -9,6 +9,7 @@ from db.python.layers.sample import SampleLayer from db.python.tables.project import ProjectPermissionsTable from models.base import SMBase +from models.models.group import ReadAccessRoles from models.models.sample import SampleUpsert from models.utils.sample_id_format import ( # Sample, sample_id_format, @@ -129,6 +130,11 @@ class GetSamplesCriteria(SMBase): @router.post('/', operation_id='getSamples') async def get_samples( criteria: GetSamplesCriteria, + meta: dict = None, + participant_ids: list[int] = None, + # project_ids is inaccurately named, it should be `project_names` + project_ids: list[str] = None, + active: bool = Body(default=True), connection: Connection = get_projectless_db_connection, ): """ @@ -139,9 +145,10 @@ async def get_samples( pt = 
ProjectPermissionsTable(connection) pids: list[int] | None = None if criteria.project_ids: - pids = await pt.get_project_ids_from_names_and_user( - connection.author, criteria.project_ids, readonly=True + projects = await pt.get_and_check_access_to_projects_for_names( + connection.author, criteria.project_ids, allowed_roles=ReadAccessRoles ) + pids = [p.id for p in projects] sample_ids_raw = ( sample_id_transform_to_raw_list(criteria.sample_ids) diff --git a/api/routes/web.py b/api/routes/web.py index 5edfb5010..e52a5620e 100644 --- a/api/routes/web.py +++ b/api/routes/web.py @@ -19,6 +19,7 @@ from db.python.layers.web import SearchItem, WebLayer from db.python.tables.project import ProjectPermissionsTable from models.enums.web import SeqrDatasetType +from models.models.group import ReadAccessRoles from models.models.search import SearchResponse from models.models.web import PagingLinks, ProjectSummary @@ -82,7 +83,7 @@ async def search_by_keyword(keyword: str, connection=get_projectless_db_connecti # raise ValueError("Test") pt = ProjectPermissionsTable(connection) projects = await pt.get_projects_accessible_by_user( - connection.author, readonly=True + connection.author, allowed_roles=ReadAccessRoles ) pmap = {p.id: p for p in projects} responses = await SearchLayer(connection).search( diff --git a/api/server.py b/api/server.py index 0e355bcfd..9a57ff1cb 100644 --- a/api/server.py +++ b/api/server.py @@ -13,8 +13,8 @@ from api import routes from api.graphql.schema import MetamistGraphQLRouter # type: ignore from api.settings import PROFILE_REQUESTS, SKIP_DATABASE_CONNECTION -from api.utils import get_openapi_schema_func from api.utils.exceptions import determine_code_from_error +from api.utils.openapi import get_openapi_schema_func from db.python.connect import SMConnections from db.python.tables.project import is_all_access from db.python.utils import get_logger diff --git a/api/utils/__init__.py b/api/utils/__init__.py index fe5ae4cdd..089785938 100644 --- a/api/utils/__init__.py +++ b/api/utils/__init__.py @@ -1,15 +1,8 @@ """Importing GCP libraries""" + from collections import defaultdict from typing import Callable, Iterable, TypeVar -from .db import ( - authenticate, - get_project_readonly_connection, - get_project_write_connection, - get_projectless_db_connection, -) -from .openapi import get_openapi_schema_func - T = TypeVar('T') X = TypeVar('X') diff --git a/api/utils/db.py b/api/utils/db.py index 10a1ca0e4..a850f8efb 100644 --- a/api/utils/db.py +++ b/api/utils/db.py @@ -196,9 +196,9 @@ async def dependable_get_read_project_connection( extra_values: dict | None = Depends(get_extra_audit_log_values), ) -> Connection: """FastAPI handler for getting connection WITH project""" - meta = {"path": request.url.path} + meta = {'path': request.url.path} if request.client: - meta["ip"] = request.client.host + meta['ip'] = request.client.host if extra_values: meta.update(extra_values) @@ -225,6 +225,9 @@ async def dependable_get_contribute_project_connection( extra_values: dict | None = Depends(get_extra_audit_log_values), ) -> Connection: """FastAPI handler for getting connection WITH project""" + meta = {"path": request.url.path} + if request.client: + meta["ip"] = request.client.host return await ProjectPermissionsTable.get_project_connection( project_name=project, author=author, diff --git a/db/python/connect.py b/db/python/connect.py index 14c40da3e..2751517a0 100644 --- a/db/python/connect.py +++ b/db/python/connect.py @@ -255,6 +255,6 @@ async def get_connection_no_project( 
project=None, on_behalf_of=None, ar_guid=ar_guid, - allowed_roles={r for r in GroupProjectRole}, + allowed_roles=set(GroupProjectRole), meta=meta, ) diff --git a/db/python/layers/analysis.py b/db/python/layers/analysis.py index 5057ce2dc..db469468b 100644 --- a/db/python/layers/analysis.py +++ b/db/python/layers/analysis.py @@ -20,6 +20,7 @@ ProportionalDateTemporalMethod, SequencingGroupInternal, ) +from models.models.group import FullWriteAccessRoles, ReadAccessRoles from models.models.project import ProjectId from models.models.sequencing_group import SequencingGroupInternalId @@ -58,7 +59,7 @@ async def get_analysis_by_id(self, analysis_id: int, check_project_id=True): project, analysis = await self.at.get_analysis_by_id(analysis_id) if check_project_id: await self.ptable.check_access_to_project_id( - self.author, project, readonly=True + self.author, project, allowed_roles=ReadAccessRoles ) return analysis @@ -114,7 +115,9 @@ async def get_sample_cram_path_map_for_seqr( participant_ids=participant_ids, ) - async def query(self, filter_: AnalysisFilter, check_project_ids=True): + async def query( + self, filter_: AnalysisFilter, check_project_ids: bool = True + ) -> list[AnalysisInternal]: """Query analyses""" analyses = await self.at.query(filter_) @@ -123,7 +126,9 @@ async def query(self, filter_: AnalysisFilter, check_project_ids=True): if check_project_ids and not filter_.project: await self.ptable.check_access_to_project_ids( - self.author, set(a.project for a in analyses), readonly=True + self.author, + set(a.project for a in analyses if a.project is not None), + allowed_roles=ReadAccessRoles, ) return analyses @@ -156,7 +161,7 @@ async def get_cram_size_proportionate_map( raise ValueError(f'start_date ({start_date}) must be after 2020-01-01') project_objs = await self.ptable.get_and_check_access_to_projects_for_ids( - project_ids=projects, user=self.author, readonly=True + project_ids=projects, user=self.author, allowed_roles=ReadAccessRoles ) project_name_map = {p.id: p.name for p in project_objs} @@ -559,7 +564,7 @@ async def add_sequencing_groups_to_analysis( if check_project_id: project_ids = await self.at.get_project_ids_for_analysis_ids([analysis_id]) await self.ptable.check_access_to_project_ids( - self.author, project_ids, readonly=False + self.author, project_ids, allowed_roles=FullWriteAccessRoles ) return await self.at.add_sequencing_groups_to_analysis( @@ -580,7 +585,7 @@ async def update_analysis( if check_project_id: project_ids = await self.at.get_project_ids_for_analysis_ids([analysis_id]) await self.ptable.check_access_to_project_ids( - self.author, project_ids, readonly=False + self.author, project_ids, allowed_roles=FullWriteAccessRoles ) await self.at.update_analysis( diff --git a/db/python/layers/assay.py b/db/python/layers/assay.py index 3328acd82..9d6a04c9d 100644 --- a/db/python/layers/assay.py +++ b/db/python/layers/assay.py @@ -6,6 +6,7 @@ from db.python.tables.sample import SampleTable from db.python.utils import NoOpAenter from models.models.assay import AssayInternal, AssayUpsertInternal +from models.models.group import FullWriteAccessRoles, ReadAccessRoles class AssayLayer(BaseLayer): @@ -27,7 +28,7 @@ async def query(self, filter_: AssayFilter = None, check_project_id=True): if check_project_id: await self.ptable.check_access_to_project_ids( - user=self.author, project_ids=projects, readonly=True + user=self.author, project_ids=projects, allowed_roles=ReadAccessRoles ) return assays @@ -40,7 +41,7 @@ async def get_assay_by_id( if check_project_id: 
await self.ptable.check_access_to_project_id( - self.author, project, readonly=True + self.author, project, allowed_roles=ReadAccessRoles ) return assay @@ -59,7 +60,7 @@ async def get_assay_by_external_id( return assay async def get_assays_for_sequencing_group_ids( - self, sequencing_group_ids: list[int], check_project_ids=True + self, sequencing_group_ids: list[int], check_project_ids: bool = True ) -> dict[int, list[AssayInternal]]: """Get assays for a list of sequencing group IDs""" projects, assays = await self.seqt.get_assays_for_sequencing_group_ids( @@ -71,7 +72,7 @@ async def get_assays_for_sequencing_group_ids( if check_project_ids: await self.ptable.check_access_to_project_ids( - self.author, projects, readonly=True + self.author, projects, allowed_roles=ReadAccessRoles ) return assays @@ -108,7 +109,7 @@ async def get_assays_by( # if we didn't specify a project, we need to check access # to the projects we got back await self.ptable.check_access_to_project_ids( - self.author, projs, readonly=True + self.author, projs, allowed_roles=ReadAccessRoles ) return seqs @@ -128,7 +129,7 @@ async def upsert_assay( [assay.sample_id] ) await self.ptable.check_access_to_project_ids( - self.author, project_ids, readonly=False + self.author, project_ids, allowed_roles=FullWriteAccessRoles ) seq_id = await self.seqt.insert_assay( @@ -144,7 +145,7 @@ async def upsert_assay( # can check the project id of the assay we're updating project_ids = await self.seqt.get_projects_by_assay_ids([assay.id]) await self.ptable.check_access_to_project_ids( - self.author, project_ids, readonly=False + self.author, project_ids, allowed_roles=FullWriteAccessRoles ) # Otherwise update await self.seqt.update_assay( @@ -170,7 +171,7 @@ async def upsert_assays( st = SampleTable(self.connection) project_ids = await st.get_project_ids_for_sample_ids(list(sample_ids)) await self.ptable.check_access_to_project_ids( - self.author, project_ids, readonly=False + self.author, project_ids, allowed_roles=FullWriteAccessRoles ) with_function = ( diff --git a/db/python/layers/audit_log.py b/db/python/layers/audit_log.py index 1cc351ecc..406408b1c 100644 --- a/db/python/layers/audit_log.py +++ b/db/python/layers/audit_log.py @@ -1,6 +1,7 @@ from db.python.layers.base import BaseLayer, Connection from db.python.tables.audit_log import AuditLogTable from models.models.audit_log import AuditLogId, AuditLogInternal +from models.models.group import ReadAccessRoles class AuditLogLayer(BaseLayer): @@ -12,7 +13,7 @@ def __init__(self, connection: Connection): # GET async def get_for_ids( - self, ids: list[AuditLogId], check_project_id=True + self, ids: list[AuditLogId], check_project_id: bool = True ) -> list[AuditLogInternal]: """Query for samples""" if not ids: @@ -22,7 +23,7 @@ async def get_for_ids( if check_project_id: projects = {log.auth_project for log in logs} await self.ptable.check_access_to_project_ids( - user=self.author, project_ids=projects, readonly=True + user=self.author, project_ids=projects, allowed_roles=ReadAccessRoles ) return logs diff --git a/db/python/layers/family.py b/db/python/layers/family.py index d861f9d32..56fb12942 100644 --- a/db/python/layers/family.py +++ b/db/python/layers/family.py @@ -13,6 +13,7 @@ from db.python.tables.sample import SampleTable from db.python.utils import GenericFilter, NotFoundError from models.models.family import FamilyInternal, PedRow, PedRowInternal +from models.models.group import FullWriteAccessRoles, ReadAccessRoles from models.models.participant import 
ParticipantUpsertInternal from models.models.project import ProjectId @@ -48,7 +49,7 @@ async def get_family_by_internal_id( family = families[0] if check_project_id: await self.ptable.check_access_to_project_ids( - self.author, projects, readonly=True + self.author, [project], allowed_roles=ReadAccessRoles ) return family @@ -101,7 +102,7 @@ async def get_families_by_ids( if check_project_ids: await self.ptable.check_access_to_project_ids( - self.connection.author, projects, readonly=True + self.connection.author, projects, allowed_roles=ReadAccessRoles ) if check_missing and len(family_ids) != len(families): @@ -124,7 +125,7 @@ async def get_families_by_participants( if check_project_ids: await self.ptable.check_access_to_project_ids( - self.connection.author, projects, readonly=True + self.connection.author, projects, allowed_roles=ReadAccessRoles ) return participant_map @@ -141,7 +142,7 @@ async def update_family( if check_project_ids: project_ids = await self.ftable.get_projects_by_family_ids([id_]) await self.ptable.check_access_to_project_ids( - self.author, project_ids, readonly=False + self.author, project_ids, allowed_roles=FullWriteAccessRoles ) return await self.ftable.update_family( diff --git a/db/python/layers/participant.py b/db/python/layers/participant.py index 6b1bffd71..a4f529723 100644 --- a/db/python/layers/participant.py +++ b/db/python/layers/participant.py @@ -21,6 +21,7 @@ split_generic_terms, ) from models.models.family import PedRowInternal +from models.models.group import FullWriteAccessRoles, ReadAccessRoles from models.models.participant import ParticipantInternal, ParticipantUpsertInternal from models.models.project import ProjectId @@ -258,7 +259,7 @@ async def get_participants_by_ids( if check_project_ids: await self.ptable.check_access_to_project_ids( - self.author, projects, readonly=True + self.author, projects, allowed_roles=ReadAccessRoles ) if not allow_missing and len(participants) != len(pids): @@ -532,7 +533,7 @@ async def get_participants_by_families( if check_project_ids: await self.ptable.check_access_to_project_ids( - self.connection.author, projects, readonly=True + self.connection.author, projects, allowed_roles=ReadAccessRoles ) return family_map @@ -601,7 +602,9 @@ async def upsert_participant( ) await self.ptable.check_access_to_project_ids( - self.connection.author, project_ids, readonly=False + self.connection.author, + project_ids, + allowed_roles=FullWriteAccessRoles, ) await self.pttable.update_participant( participant_id=participant.id, @@ -664,7 +667,9 @@ async def update_many_participant_external_ids( list(internal_to_external_id.keys()) ) await self.ptable.check_access_to_project_ids( - user=self.author, project_ids=projects, readonly=False + user=self.author, + project_ids=projects, + allowed_roles=FullWriteAccessRoles, ) return await self.pttable.update_many_participant_external_ids( @@ -950,7 +955,7 @@ async def check_project_access_for_participants_families( return await self.ptable.check_access_to_project_ids( self.connection.author, list(pprojects | fprojects), - readonly=True, + allowed_roles=ReadAccessRoles, ) async def update_participant_family( diff --git a/db/python/layers/sample.py b/db/python/layers/sample.py index 45858db1f..7d805284c 100644 --- a/db/python/layers/sample.py +++ b/db/python/layers/sample.py @@ -9,6 +9,7 @@ from db.python.tables.project import ProjectPermissionsTable from db.python.tables.sample import SampleFilter, SampleTable from db.python.utils import GenericFilter, NotFoundError +from 
models.models.group import FullWriteAccessRoles, ReadAccessRoles from models.models.project import ProjectId from models.models.sample import SampleInternal, SampleUpsertInternal from models.utils.sample_id_format import sample_id_format_list @@ -29,7 +30,7 @@ async def get_by_id(self, sample_id: int, check_project_id=True) -> SampleIntern project, sample = await self.st.get_sample_by_id(sample_id) if check_project_id: await self.pt.check_access_to_project_ids( - self.connection.author, [project], readonly=True + self.connection.author, [project], allowed_roles=ReadAccessRoles ) return sample @@ -44,7 +45,7 @@ async def query( if check_project_ids: await self.pt.check_access_to_project_ids( - self.connection.author, projects, readonly=True + self.connection.author, projects, allowed_roles=ReadAccessRoles ) return samples @@ -65,7 +66,7 @@ async def get_samples_by_participants( if check_project_ids: await self.ptable.check_access_to_project_ids( - self.author, projects, readonly=True + self.author, projects, allowed_roles=ReadAccessRoles ) grouped_samples = group_by(samples, lambda s: s.participant_id) @@ -83,7 +84,7 @@ async def get_sample_by_id( project, sample = await self.st.get_sample_by_id(sample_id) if check_project_id: await self.pt.check_access_to_project_ids( - self.author, [project], readonly=True + self.author, [project], allowed_roles=ReadAccessRoles ) return sample @@ -148,7 +149,7 @@ async def get_internal_to_external_sample_id_map( if check_project_ids: await self.ptable.check_access_to_project_ids( - self.author, projects, readonly=True + self.author, projects, allowed_roles=ReadAccessRoles ) return sample_id_map @@ -177,7 +178,7 @@ async def get_samples_by( # so no else required pjcts = await self.st.get_project_ids_for_sample_ids(sample_ids) await self.ptable.check_access_to_project_ids( - self.author, pjcts, readonly=True + self.author, pjcts, allowed_roles=ReadAccessRoles ) _returned_project_ids, samples = await self.st.query( @@ -194,7 +195,7 @@ async def get_samples_by( if not project_ids and check_project_ids: await self.ptable.check_access_to_project_ids( - self.author, _returned_project_ids, readonly=True + self.author, _returned_project_ids, allowed_roles=ReadAccessRoles ) return samples @@ -211,7 +212,9 @@ async def get_samples_create_date( ) -> dict[int, datetime.date]: """Get a map of {internal_sample_id: date_created} for list of sample_ids""" pjcts = await self.st.get_project_ids_for_sample_ids(sample_ids) - await self.pt.check_access_to_project_ids(self.author, pjcts, readonly=True) + await self.pt.check_access_to_project_ids( + self.author, pjcts, allowed_roles=ReadAccessRoles + ) return await self.st.get_samples_create_date(sample_ids) # CREATE / UPDATES @@ -287,7 +290,7 @@ async def upsert_samples( if sids: pjcts = await self.st.get_project_ids_for_sample_ids(sids) await self.ptable.check_access_to_project_ids( - self.author, pjcts, readonly=False + self.author, pjcts, allowed_roles=FullWriteAccessRoles ) async with with_function(): @@ -329,7 +332,9 @@ async def merge_samples( if check_project_id: projects = await self.st.get_project_ids_for_sample_ids([id_keep, id_merge]) await self.ptable.check_access_to_project_ids( - user=self.author, project_ids=projects, readonly=False + user=self.author, + project_ids=projects, + allowed_roles=FullWriteAccessRoles, ) return await self.st.merge_samples( @@ -351,7 +356,7 @@ async def update_many_participant_ids( if check_sample_ids: project_ids = await self.st.get_project_ids_for_sample_ids(ids) await 
self.ptable.check_access_to_project_ids( - self.author, project_ids, readonly=False + self.author, project_ids, allowed_roles=FullWriteAccessRoles ) await self.st.update_many_participant_ids( @@ -368,7 +373,7 @@ async def get_history_of_sample( if check_sample_ids: project_ids = set(r.project for r in rows) await self.ptable.check_access_to_project_ids( - self.author, project_ids, readonly=True + self.author, project_ids, allowed_roles=ReadAccessRoles ) return rows diff --git a/db/python/layers/seqr.py b/db/python/layers/seqr.py index ef26c6e15..f34cff4eb 100644 --- a/db/python/layers/seqr.py +++ b/db/python/layers/seqr.py @@ -38,6 +38,7 @@ # literally the most temporary thing ever, but for complete # automation need to have sample inclusion / exclusion +from models.models.group import ReadAccessRoles from models.utils.sequencing_group_id_format import ( sequencing_group_id_format, sequencing_group_id_format_list, @@ -137,7 +138,7 @@ async def sync_dataset( project = await self.ptable.get_and_check_access_to_project_for_id( self.connection.author, project_id=self.connection.project, - readonly=True, + allowed_roles=ReadAccessRoles, ) seqr_guid = project.meta.get( @@ -731,9 +732,11 @@ def _parse_consanguity(consanguity): def process_row(row): return { - seqr_key: key_processor[sm_key](row[sm_key]) - if sm_key in key_processor - else row[sm_key] + seqr_key: ( + key_processor[sm_key](row[sm_key]) + if sm_key in key_processor + else row[sm_key] + ) for seqr_key, sm_key in seqr_map.items() if sm_key in row } diff --git a/db/python/layers/sequencing_group.py b/db/python/layers/sequencing_group.py index e68150c76..c88ad066a 100644 --- a/db/python/layers/sequencing_group.py +++ b/db/python/layers/sequencing_group.py @@ -10,6 +10,7 @@ SequencingGroupTable, ) from db.python.utils import NotFoundError +from models.models.group import ReadAccessRoles from models.models.project import ProjectId from models.models.sequencing_group import ( SequencingGroupInternal, @@ -54,7 +55,7 @@ async def get_sequencing_groups_by_ids( if check_project_ids: await self.ptable.check_access_to_project_ids( - self.author, projects, readonly=True + self.author, projects, allowed_roles=ReadAccessRoles ) if len(groups) != len(sequencing_group_ids): @@ -84,7 +85,7 @@ async def get_sequencing_groups_by_analysis_ids( if check_project_ids: await self.ptable.check_access_to_project_ids( - self.author, projects, readonly=True + self.author, projects, allowed_roles=ReadAccessRoles ) return groups @@ -103,7 +104,7 @@ async def query( if check_project_ids and not (filter_.project and filter_.project.in_): await self.ptable.check_access_to_project_ids( - self.author, projects, readonly=True + self.author, projects, allowed_roles=ReadAccessRoles ) return sequencing_groups @@ -154,7 +155,7 @@ async def get_participant_ids_sequencing_group_ids_for_sequencing_type( if check_project_ids: await self.ptable.check_access_to_project_ids( - self.author, projects, readonly=True + self.author, projects, allowed_roles=ReadAccessRoles ) return pids diff --git a/db/python/layers/web.py b/db/python/layers/web.py index 05be7d240..981edfebb 100644 --- a/db/python/layers/web.py +++ b/db/python/layers/web.py @@ -25,6 +25,7 @@ SearchItem, parse_sql_bool, ) +from models.models.group import ReadAccessRoles from models.models.web import ProjectSummaryInternal, WebProject @@ -282,7 +283,7 @@ async def get_project_summary( sample_query, values = self._project_summary_sample_query(grid_filter) ptable = ProjectPermissionsTable(self._connection) project_db = 
await ptable.get_and_check_access_to_project_for_id( - self.author, self.project, readonly=True + self.author, self.project, allowed_roles=ReadAccessRoles ) project = WebProject( id=project_db.id, diff --git a/db/python/tables/project.py b/db/python/tables/project.py index 9a92e1985..0df6adac4 100644 --- a/db/python/tables/project.py +++ b/db/python/tables/project.py @@ -32,7 +32,7 @@ class ProjectPermissionsTable: @staticmethod def get_project_group_name(project_name: str, role: GroupProjectRole) -> str: """ - Get group name for a project, for readonly / write + Get group name for a project, for the given role """ return f'{project_name}-{role.name}' @@ -163,7 +163,7 @@ async def get_projects_accessible_by_user( self, user: str, allowed_roles: set[GroupProjectRole], - project_id_filter: list[int] | None, + project_id_filter: list[int] | None = None, ) -> list[Project]: """ Get projects that are accessible by the specified user @@ -301,6 +301,7 @@ async def check_access_to_project_id( allowed_roles: set[GroupProjectRole], raise_exception: bool = True, ) -> bool: + """Check user has access to a single project id""" project = await self.get_and_check_access_to_project_for_id( user, project_id, allowed_roles, raise_exception ) diff --git a/models/models/group.py b/models/models/group.py index 9d2bb3db7..603404d72 100644 --- a/models/models/group.py +++ b/models/models/group.py @@ -4,6 +4,16 @@ GroupProjectRole = Enum('GroupProjectRole', ['read', 'contribute', 'write']) +# These roles have read access to a project +ReadAccessRoles = { + GroupProjectRole.read, + GroupProjectRole.contribute, + GroupProjectRole.write, +} + +# Only write has full write access +FullWriteAccessRoles = {GroupProjectRole.write} + class Group(SMBase): """Row for project in 'project' table""" diff --git a/test/test_project_groups.py b/test/test_project_groups.py index d02e884f1..e545f6986 100644 --- a/test/test_project_groups.py +++ b/test/test_project_groups.py @@ -8,6 +8,7 @@ ProjectPermissionsTable, ) from db.python.utils import NotFoundError +from models.models.group import GroupProjectRole, ReadAccessRoles class TestGroupAccess(DbIsolatedTest): @@ -146,11 +147,12 @@ async def test_project_create_succeed(self): project_id = await self.pttable.create_project(g, g, self.author) # pylint: disable=protected-access - project = await self.pttable._get_project_by_id(project_id) + await self.pttable._get_project_by_id(project_id) # test that the group names make sense - self.assertIsNotNone(project.read_group_id) - self.assertIsNotNone(project.write_group_id) + # @TODO update tests to handle new role groups + # self.assertIsNotNone(project.read_group_id) + # self.assertIsNotNone(project.write_group_id) class TestProjectAccess(DbIsolatedTest): @@ -195,12 +197,12 @@ async def test_no_project_access(self): project_id = await self.pttable.create_project(g, g, self.author) with self.assertRaises(Forbidden): await self.pttable.get_and_check_access_to_project_for_id( - user=self.author, project_id=project_id, readonly=True + user=self.author, project_id=project_id, allowed_roles=ReadAccessRoles ) with self.assertRaises(Forbidden): await self.pttable.get_and_check_access_to_project_for_name( - user=self.author, project_name=g, readonly=True + user=self.author, project_name=g, allowed_roles=ReadAccessRoles ) @run_as_sync @@ -215,18 +217,20 @@ async def test_project_access_success(self): pid = await self.pttable.create_project(g, g, self.author) await self.pttable.set_group_members( - group_name=self.pttable.get_project_group_name(g, 
readonly=True), + group_name=self.pttable.get_project_group_name( + g, role=GroupProjectRole.read + ), members=[self.author], author=self.author, ) project_for_id = await self.pttable.get_and_check_access_to_project_for_id( - user=self.author, project_id=pid, readonly=True + user=self.author, project_id=pid, allowed_roles=ReadAccessRoles ) self.assertEqual(pid, project_for_id.id) project_for_name = await self.pttable.get_and_check_access_to_project_for_name( - user=self.author, project_name=g, readonly=True + user=self.author, project_name=g, allowed_roles=ReadAccessRoles ) self.assertEqual(pid, project_for_name.id) @@ -243,13 +247,15 @@ async def test_get_my_projects(self): pid = await self.pttable.create_project(g, g, self.author) await self.pttable.set_group_members( - group_name=self.pttable.get_project_group_name(g, readonly=True), + group_name=self.pttable.get_project_group_name( + g, role=GroupProjectRole.read + ), members=[self.author], author=self.author, ) projects = await self.pttable.get_projects_accessible_by_user( - author=self.author + user=self.author, allowed_roles=ReadAccessRoles ) self.assertEqual(1, len(projects)) diff --git a/test/testbase.py b/test/testbase.py index fa60ee66f..9c8121821 100644 --- a/test/testbase.py +++ b/test/testbase.py @@ -25,6 +25,7 @@ SMConnections, ) from db.python.tables.project import ProjectPermissionsTable +from models.models.group import FullWriteAccessRoles from models.models.project import ProjectId # use this to determine where the db directory is relatively, @@ -137,7 +138,7 @@ async def setup(): formed_connection = Connection( connection=sm_db, author=cls.author, - readonly=False, + allowed_roles=FullWriteAccessRoles, on_behalf_of=None, ar_guid=None, project=None, @@ -187,7 +188,7 @@ def setUp(self) -> None: connection=self._connection, project=self.project_id, author=self.author, - readonly=False, + allowed_roles=FullWriteAccessRoles, ar_guid=None, on_behalf_of=None, ) From 1d7358d4860f07e7d657110a6b811f7c0f2cc7b2 Mon Sep 17 00:00:00 2001 From: Dan Coates Date: Mon, 25 Mar 2024 14:05:22 +1100 Subject: [PATCH 04/29] fix typo --- db/python/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/db/python/utils.py b/db/python/utils.py index 14f655108..cc85e6248 100644 --- a/db/python/utils.py +++ b/db/python/utils.py @@ -77,7 +77,7 @@ def __init__( super().__init__( f'{author} does not have {allowed_roles_str} access to resources from the ' - f'follgroowing project(s), or they may not exist: {project_names_str}', + f'following project(s), or they may not exist: {project_names_str}', *args, ) From 25bb2e28b22060d159eede679426c6bce1066366 Mon Sep 17 00:00:00 2001 From: Dan Coates Date: Wed, 10 Apr 2024 16:24:35 +1000 Subject: [PATCH 05/29] fix merge problems and circular import problems --- api/routes/analysis_runner.py | 4 ++-- db/python/tables/project.py | 8 ++++++-- db/python/utils.py | 4 +--- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/api/routes/analysis_runner.py b/api/routes/analysis_runner.py index 2ee383a55..893fba262 100644 --- a/api/routes/analysis_runner.py +++ b/api/routes/analysis_runner.py @@ -4,7 +4,7 @@ from api.utils.db import ( Connection, - get_project_readonly_connection, + get_project_read_connection, get_project_write_connection, ) from db.python.layers.analysis_runner import AnalysisRunnerLayer @@ -75,7 +75,7 @@ async def get_analysis_runner_logs( repository: str | None = None, access_level: str | None = None, environment: str | None = None, - connection: Connection = 
get_project_readonly_connection, + connection: Connection = get_project_read_connection, ) -> list[AnalysisRunner]: """Get analysis runner logs""" diff --git a/db/python/tables/project.py b/db/python/tables/project.py index 0df6adac4..c56be29f8 100644 --- a/db/python/tables/project.py +++ b/db/python/tables/project.py @@ -215,7 +215,9 @@ async def get_and_check_access_to_project_for_id( if not has_access: if raise_exception: raise NoProjectAccess( - [project.name], allowed_roles=allowed_roles, author=user + [project.name], + allowed_roles=[r.name for r in allowed_roles], + author=user, ) return None @@ -289,7 +291,9 @@ async def get_and_check_access_to_projects_for_ids( if missing_project_names: raise NoProjectAccess( - missing_project_names, allowed_roles=allowed_roles, author=user + missing_project_names, + allowed_roles=[r.name for r in allowed_roles], + author=user, ) return accessible_projects diff --git a/db/python/utils.py b/db/python/utils.py index cc85e6248..3b73660c4 100644 --- a/db/python/utils.py +++ b/db/python/utils.py @@ -3,11 +3,9 @@ import logging import os import re -from enum import Enum from typing import Any, Generic, TypeVar from models.base import SMBase -from models.models.group import GroupProjectRole T = TypeVar('T') @@ -67,7 +65,7 @@ def __init__( self, project_names: list[str], author: str, - allowed_roles: set[GroupProjectRole], + allowed_roles: list[str], *args: tuple[Any, ...], ): project_names_str = ( From 177848c005ad3a218b3324b2dd9c94566e2b4c41 Mon Sep 17 00:00:00 2001 From: Dan Coates Date: Fri, 14 Jun 2024 16:23:01 +1000 Subject: [PATCH 06/29] Update GraphQLFilters to fix generics for all_values method --- api/graphql/filters.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/graphql/filters.py b/api/graphql/filters.py index d9b7e8775..8cd7017be 100644 --- a/api/graphql/filters.py +++ b/api/graphql/filters.py @@ -22,11 +22,11 @@ class GraphQLFilter(Generic[T]): contains: T | None = None icontains: T | None = None - def all_values(self): + def all_values(self) -> list[T]: """ Get all values used anywhere in a filter, useful for getting values to map later """ - v = [] + v: list[T] = [] if self.eq: v.append(self.eq) if self.in_: From cf2bde3e63fa363c3704b1fcb27bb11b4a7a4041 Mon Sep 17 00:00:00 2001 From: Dan Coates Date: Fri, 14 Jun 2024 16:23:52 +1000 Subject: [PATCH 07/29] simplify migration, liquibase rollback support is shaky So best to limit the destructive updates in the migration and do them manually --- db/project.xml | 85 ++++++++++++++++++++++---------------------------- 1 file changed, 38 insertions(+), 47 deletions(-) diff --git a/db/project.xml b/db/project.xml index beb634abb..f41f375b0 100644 --- a/db/project.xml +++ b/db/project.xml @@ -1282,77 +1282,68 @@ ALTER TABLE analysis_cohort ADD SYSTEM VERSIONING; - + SET @@system_versioning_alter_history = 1; - + - - - - - + + + + + + + + + + ALTER TABLE project_member ADD SYSTEM VERSIONING; + + + + -- copy across read and write roles from existing project table - INSERT INTO project_groups ( + INSERT INTO project_member ( SELECT - id AS project_id, - read_group_id AS group_id, - 'read' AS role, - audit_log_id - FROM project - WHERE read_group_id IS NOT NULL + p.id AS project_id, + gm.member as member, + 'reader' AS role, + coalesce(gm.audit_log_id, p.audit_log_id, g.audit_log_id) as audit_log_id + FROM project p + JOIN `group` g on g.id = p.read_group_id + JOIN group_member gm on gm.group_id = g.id + WHERE p.read_group_id IS NOT NULL UNION ALL SELECT - id AS project_id, 
- write_group_id AS group_id, - 'write' AS role, - audit_log_id - FROM project - WHERE write_group_id IS NOT NULL + p.id AS project_id, + gm.member as member, + 'writer' AS role, + coalesce(gm.audit_log_id, p.audit_log_id, g.audit_log_id) as audit_log_id + FROM project p + JOIN `group` g on g.id = p.write_group_id + JOIN group_member gm on gm.group_id = g.id + WHERE p.read_group_id IS NOT NULL ) - + + + - - - - - - From 18bc4bb577a2d5b88ee75f73f479d92ce7335a0a Mon Sep 17 00:00:00 2001 From: Dan Coates Date: Fri, 14 Jun 2024 16:51:57 +1000 Subject: [PATCH 08/29] Move project permission checks from project table to connection This way they are accessible pretty much everywhere, but are only calculated once. The permission checks themselves are now synchronous and should be really fast, so no need for avoiding checking project ids --- api/utils/db.py | 212 ++++--------------- db/python/connect.py | 230 ++++++++++++++++---- db/python/tables/project.py | 408 +++++++----------------------------- db/python/utils.py | 7 +- models/models/group.py | 29 --- models/models/project.py | 47 ++++- 6 files changed, 363 insertions(+), 570 deletions(-) delete mode 100644 models/models/group.py diff --git a/api/utils/db.py b/api/utils/db.py index a850f8efb..aa0cb38a0 100644 --- a/api/utils/db.py +++ b/api/utils/db.py @@ -1,6 +1,7 @@ import json import logging from os import getenv +from typing import Any from fastapi import Depends, HTTPException, Request from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -11,8 +12,7 @@ from api.utils.gcp import email_from_id_token from db.python.connect import Connection, SMConnections from db.python.gcp_connect import BqConnection, PubSubConnection -from db.python.tables.project import ProjectPermissionsTable -from models.models.group import GroupProjectRole +from models.models.project import ProjectMemberRole EXPECTED_AUDIENCE = getenv('SM_OAUTHAUDIENCE') @@ -31,7 +31,7 @@ def get_ar_guid(request: Request) -> str | None: return request.headers.get('sm-ar-guid') -def get_extra_audit_log_values(request: Request) -> dict | None: +def get_extra_audit_log_values(request: Request) -> dict[str, Any] | None: """Get a JSON encoded dictionary from the 'sm-extra-values' header if it exists""" headers = request.headers.get('sm-extra-values') if not headers: @@ -79,170 +79,43 @@ def authenticate( raise HTTPException(status_code=401, detail='Not authenticated :(') -async def dependable_get_write_project_connection( - project: str, - request: Request, - author: str = Depends(authenticate), - ar_guid: str = Depends(get_ar_guid), - extra_values: dict | None = Depends(get_extra_audit_log_values), - on_behalf_of: str | None = Depends(get_on_behalf_of), -) -> Connection: - """FastAPI handler for getting connection WITH project""" - meta = {'path': request.url.path} - if request.client: - meta['ip'] = request.client.host - if extra_values: - meta.update(extra_values) - - return await ProjectPermissionsTable.get_project_connection( - project_name=project, - author=author, - allowed_roles={GroupProjectRole.write}, - ar_guid=ar_guid, - on_behalf_of=on_behalf_of, - meta=meta, - ) - - -async def HACK_dependable_contributor_project_connection( - project: str, - author: str = Depends(authenticate), - ar_guid: str = Depends(get_ar_guid), - extra_values: dict | None = Depends(get_extra_audit_log_values), -) -> Connection: - """FastAPI handler for getting connection WITH project""" - meta = {} - if extra_values: - meta.update(extra_values) - - meta['role'] = 'contributor' - - # 
hack by making it appear readonly - connection = await ProjectPermissionsTable.get_project_connection( - project_name=project, - author=author, - readonly=True, - on_behalf_of=None, - ar_guid=ar_guid, - meta=meta, - ) - - # then hack it so - connection.readonly = False - - return connection - - -async def HACK_dependable_contributor_project_connection( - project: str, - request: Request, - author: str = Depends(authenticate), - ar_guid: str = Depends(get_ar_guid), - extra_values: dict | None = Depends(get_extra_audit_log_values), -) -> Connection: - """FastAPI handler for getting connection WITH project""" - meta = {"path": request.url.path} - if request.client: - meta["ip"] = request.client.host - - if extra_values: - meta.update(extra_values) - - meta['role'] = 'contributor' - - # hack by making it appear readonly - connection = await ProjectPermissionsTable.get_project_connection( - project_name=project, - author=author, - allowed_roles={ - GroupProjectRole.read, - GroupProjectRole.write, - GroupProjectRole.contribute, - }, - on_behalf_of=None, - ar_guid=ar_guid, - meta=meta, - ) - - -async def dependable_get_contribute_project_connection( - project: str, - request: Request, - author: str = Depends(authenticate), - ar_guid: str = Depends(get_ar_guid), - extra_values: dict | None = Depends(get_extra_audit_log_values), -) -> Connection: - """FastAPI handler for getting connection WITH project""" - return await ProjectPermissionsTable.get_project_connection( - project_name=project, - author=author, - allowed_roles={GroupProjectRole.write, GroupProjectRole.contribute}, - on_behalf_of=None, - ar_guid=ar_guid, - meta=meta, - ) - - # then hack it so - connection.readonly = False - - return connection - - -async def dependable_get_read_project_connection( - project: str, - request: Request, - author: str = Depends(authenticate), - ar_guid: str = Depends(get_ar_guid), - extra_values: dict | None = Depends(get_extra_audit_log_values), -) -> Connection: - """FastAPI handler for getting connection WITH project""" - meta = {'path': request.url.path} - if request.client: - meta['ip'] = request.client.host - - if extra_values: - meta.update(extra_values) - - return await ProjectPermissionsTable.get_project_connection( - project_name=project, - author=author, - allowed_roles={ - GroupProjectRole.read, - GroupProjectRole.write, - GroupProjectRole.contribute, - }, - on_behalf_of=None, - ar_guid=ar_guid, - meta=meta, - ) - +def dependable_get_project_db_connection(allowed_roles: set[ProjectMemberRole]): + """Return a partially applied dependable db connection with allowed roles applied""" + + async def dependable_project_db_connection( + project: str, + request: Request, + author: str = Depends(authenticate), + ar_guid: str = Depends(get_ar_guid), + extra_values: dict[str, Any] | None = Depends(get_extra_audit_log_values), + on_behalf_of: str | None = Depends(get_on_behalf_of), + ) -> Connection: + """FastAPI handler for getting connection WITH project""" + meta = {'path': request.url.path} + if request.client: + meta['ip'] = request.client.host + + if extra_values: + meta.update(extra_values) + + return await SMConnections.get_connection_with_project( + project_name=project, + author=author, + allowed_roles=allowed_roles, + on_behalf_of=on_behalf_of, + ar_guid=ar_guid, + meta=meta, + ) -async def dependable_get_contribute_project_connection( - project: str, - request: Request, - author: str = Depends(authenticate), - ar_guid: str = Depends(get_ar_guid), - extra_values: dict | None = 
Depends(get_extra_audit_log_values), -) -> Connection: - """FastAPI handler for getting connection WITH project""" - meta = {"path": request.url.path} - if request.client: - meta["ip"] = request.client.host - return await ProjectPermissionsTable.get_project_connection( - project_name=project, - author=author, - allowed_roles={GroupProjectRole.write, GroupProjectRole.contribute}, - on_behalf_of=None, - ar_guid=ar_guid, - meta=meta, - ) + return dependable_project_db_connection async def dependable_get_connection( request: Request, author: str = Depends(authenticate), ar_guid: str = Depends(get_ar_guid), - extra_values: dict | None = Depends(get_extra_audit_log_values), + extra_values: dict[str, Any] | None = Depends(get_extra_audit_log_values), + on_behalf_of: str | None = Depends(get_on_behalf_of), ): """FastAPI handler for getting connection withOUT project""" meta = {'path': request.url.path} @@ -253,7 +126,7 @@ async def dependable_get_connection( meta.update(extra_values) return await SMConnections.get_connection_no_project( - author, ar_guid=ar_guid, meta=meta + author, ar_guid=ar_guid, meta=meta, on_behalf_of=on_behalf_of ) @@ -269,7 +142,7 @@ async def dependable_get_pubsub_connection( return await PubSubConnection.get_connection_no_project(author, topic) -def validate_iap_jwt_and_get_email(iap_jwt, audience): +def validate_iap_jwt_and_get_email(iap_jwt: str, audience: str): """ Validate an IAP JWT and return email Source: https://cloud.google.com/iap/docs/signed-headers-howto @@ -292,14 +165,13 @@ def validate_iap_jwt_and_get_email(iap_jwt, audience): get_author = Depends(authenticate) -get_project_read_connection = Depends(dependable_get_read_project_connection) -get_project_contribute_connection = Depends( - dependable_get_contribute_project_connection -) -HACK_get_project_contributor_connection = Depends( - HACK_dependable_contributor_project_connection -) -get_project_write_connection = Depends(dependable_get_write_project_connection) + + +def get_project_db_connection(allowed_roles: set[ProjectMemberRole]): + """Get a project db connection with allowed roles applied""" + return Depends(dependable_get_project_db_connection(allowed_roles)) + + get_projectless_db_connection = Depends(dependable_get_connection) get_projectless_bq_connection = Depends(dependable_get_bq_connection) get_projectless_pubsub_connection = Depends(dependable_get_pubsub_connection) diff --git a/db/python/connect.py b/db/python/connect.py index 2751517a0..65d14caf3 100644 --- a/db/python/connect.py +++ b/db/python/connect.py @@ -3,18 +3,24 @@ """ Code for connecting to Postgres database """ - import abc import asyncio import json import logging import os +from typing import Iterable import databases from api.settings import LOG_DATABASE_QUERIES -from db.python.utils import InternalError -from models.models.group import GroupProjectRole +from db.python.tables.project import ProjectPermissionsTable +from db.python.utils import ( + InternalError, + NoProjectAccess, + NotFoundError, + ProjectDoesNotExist, +) +from models.models.project import Project, ProjectId, ProjectMemberRole logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) @@ -50,35 +56,151 @@ class Connection: def __init__( self, connection: databases.Database, - project: int | None, + project: Project | None, + project_id_map: dict[ProjectId, Project], + project_name_map: dict[str, Project], author: str, on_behalf_of: str | None, - allowed_roles: set[GroupProjectRole], ar_guid: str | None, meta: dict[str, str] | None = None, 
): self.connection: databases.Database = connection - self.project: int | None = project + self.project: Project | None = project + self.project_id_map = project_id_map + self.project_name_map = project_name_map self.author: str = author self.on_behalf_of: str | None = on_behalf_of - self.allowed_roles: set[GroupProjectRole] = allowed_roles self.ar_guid: str | None = ar_guid self.meta = meta self._audit_log_id: int | None = None self._audit_log_lock = asyncio.Lock() - async def audit_log_id(self): - """Get audit_log ID for write operations, cached per connection""" - # If connection doesn't have a writeable role, don't allow getting an audit log id - if not self.allowed_roles & { - GroupProjectRole.write, - GroupProjectRole.contribute, - }: + @property + def project_id(self): + """Safely get the project id from the project model attached to the connection""" + return self.project.id if self.project is not None else None + + def all_projects(self): + """Return all projects that the current user has access to""" + return list(self.project_id_map.values()) + + def get_and_check_access_to_projects( + self, projects: Iterable[Project], allowed_roles: set[ProjectMemberRole] + ): + """ + Check if the current user has _any_ of the specified roles in _all_ of the + specified projects. Raise an error if they do not. + """ + # projects that the user has some access to, but not the required access + disallowed_projects = [p for p in projects if not p.roles & allowed_roles] + + if disallowed_projects: + raise NoProjectAccess( + [p.name for p in disallowed_projects], + allowed_roles=[r.name for r in allowed_roles], + author=self.author, + ) + + return projects + + def get_and_check_access_to_projects_for_ids( + self, project_ids: Iterable[ProjectId], allowed_roles: set[ProjectMemberRole] + ): + """ + Check if the current user has _any_ of the specified roles in _all_ of the + projects based on the specified project ids. Raise an error if they do not. + Also raise an error if any of the specified project ids doesn't exist or the + current user has no access to it. Return the matching projects + """ + projects = [ + self.project_id_map[id] for id in project_ids if id in self.project_id_map + ] + + # Check if any of the provided ids aren't valid project ids, or the user has + # no access to them at all. A NotFoundError is raised here rather than a + # Forbidden so as to not leak the existence of the project to those with no access. + missing_project_ids = set(project_ids) - set(p.id for p in projects) + if missing_project_ids: + missing_project_ids_str = ', '.join([str(p) for p in missing_project_ids]) + raise NotFoundError( + f'Could not find projects with ids: {missing_project_ids_str}' + ) + + return self.get_and_check_access_to_projects(projects, allowed_roles) + + def check_access_to_projects_for_ids( + self, project_ids: Iterable[ProjectId], allowed_roles: set[ProjectMemberRole] + ): + """ + Check if the current user has _any_ of the specified roles in _all_ of the + projects based on the specified project ids. Raise an error if they do not. + Also raise an error if any of the specified project ids doesn't exist or the + current user has no access to it. 
Returns None + """ + self.get_and_check_access_to_projects_for_ids(project_ids, allowed_roles) + + def get_and_check_access_to_projects_for_names( + self, project_names: Iterable[str], allowed_roles: set[ProjectMemberRole] + ): + """ + Check if the current user has _any_ of the specified roles in _all_ of the + projects based on the specified project names. Raise an error if they do not. + Also raise an error if any of the specified project names doesn't exist or + the current user has no access to it. Return the matching projects + """ + projects = [ + self.project_name_map[name] + for name in project_names + if name in self.project_name_map + ] + + # Check if any of the provided names aren't valid project names, or the user has + # no access to them at all. A NotFoundError is raised here rather than a + # Forbidden so as to not leak the existence of the project to those with no access. + missing_project_names = set(project_names) - set(p.name for p in projects) + + if missing_project_names: + missing_project_names_str = ', '.join( + [f'"{str(p)}"' for p in missing_project_names] + ) + raise NotFoundError( + f'Could not find projects with names: {missing_project_names_str}' + ) + + return self.get_and_check_access_to_projects(projects, allowed_roles) + + def check_access_to_projects_for_names( + self, project_names: Iterable[str], allowed_roles: set[ProjectMemberRole] + ): + """ + Check if the current user has _any_ of the specified roles in _all_ of the + projects based on the specified project names. Raise an error if they do not. + Also raise an error if any of the specified project names doesn't exist or + the current user has no access to it. Returns None + """ + self.get_and_check_access_to_projects_for_names(project_names, allowed_roles) + + def check_access(self, allowed_roles: set[ProjectMemberRole]): + """ + Check if the current user has the specified role within the project that is + attached to the connection. If there is no project attached to the connection + this will raise an error. 
+ """ + if self.project is None: raise InternalError( - 'Trying to get a audit_log ID, but not a write connection' + 'Connection was expected to have a project attached, but did not' + ) + if not allowed_roles & self.project.roles: + raise NoProjectAccess( + project_names=[self.project.name], + author=self.author, + allowed_roles=[r.name for r in allowed_roles], ) + async def audit_log_id(self): + """Get audit_log ID for write operations, cached per connection""" + async with self._audit_log_lock: if not self._audit_log_id: @@ -93,26 +215,18 @@ async def audit_log_id(self): on_behalf_of=self.on_behalf_of, ar_guid=self.ar_guid, comment=None, - project=self.project, + project=self.project_id, meta=self.meta, ) return self._audit_log_id - def assert_requires_project(self): - """Assert the project is set, or return an exception""" - if self.project is None: - raise InternalError( - 'An internal error has occurred when passing the project context, ' - 'please send this stacktrace to your system administrator' - ) - class DatabaseConfiguration(abc.ABC): """Base class for DatabaseConfiguration""" @abc.abstractmethod - def get_connection_string(self): + def get_connection_string(self) -> str: """Get connection string""" raise NotImplementedError @@ -132,11 +246,11 @@ class CredentialedDatabaseConfiguration(DatabaseConfiguration): def __init__( self, - dbname, - host=None, - port=None, - username=None, - password=None, + dbname: str, + host: str | None = None, + port: str | None = None, + username: str | None = None, + password: str | None = None, ): self.dbname = dbname self.host = host @@ -160,6 +274,8 @@ def get_connection_string(self): """Prepares the connection string for mysql / mariadb""" _host = self.host or 'localhost' + + assert self.username u_p = self.username if self.password: @@ -180,7 +296,7 @@ def get_connection_string(self): class SMConnections: """Contains useful functions for connecting to the database""" - _credentials: DatabaseConfiguration | None = None + _credentials: CredentialedDatabaseConfiguration | None = None @staticmethod def _get_config(): @@ -236,9 +352,48 @@ async def get_made_connection(): return conn + @staticmethod + async def get_connection_with_project( + *, + author: str, + project_name: str, + allowed_roles: set[ProjectMemberRole], + ar_guid: str, + on_behalf_of: str | None = None, + meta: dict[str, str] | None = None, + ): + """Get a db connection from a project and user""" + # maybe it makes sense to perform permission checks here too + logger.debug(f'Authenticate connection to {project_name} with {author!r}') + + conn = await SMConnections.get_made_connection() + pt = ProjectPermissionsTable(connection=None, database_connection=conn) + + project_id_map, project_name_map = await pt.get_projects_accessible_by_user( + user=author + ) + + if project_name not in project_name_map: + raise ProjectDoesNotExist(project_name) + + connection = Connection( + connection=conn, + author=author, + project=project_name_map[project_name], + project_id_map=project_id_map, + project_name_map=project_name_map, + on_behalf_of=on_behalf_of, + ar_guid=ar_guid, + meta=meta, + ) + + connection.check_access(allowed_roles) + + return connection + @staticmethod async def get_connection_no_project( - author: str, ar_guid: str, meta: dict[str, str] + author: str, ar_guid: str, meta: dict[str, str], on_behalf_of: str | None ): """Get a db connection from a project and user""" # maybe it makes sense to perform permission checks here too @@ -246,15 +401,18 @@ async def 
get_connection_no_project( conn = await SMConnections.get_made_connection() - # all roles are allowed here as we don't authenticate project-less connection, - # but rely on the the endpoint to validate the resources + pt = ProjectPermissionsTable(connection=None, database_connection=conn) + project_id_map, project_name_map = await pt.get_projects_accessible_by_user( + user=author + ) return Connection( connection=conn, author=author, project=None, - on_behalf_of=None, + on_behalf_of=on_behalf_of, ar_guid=ar_guid, - allowed_roles=set(GroupProjectRole), meta=meta, + project_id_map=project_id_map, + project_name_map=project_name_map, ) diff --git a/db/python/tables/project.py b/db/python/tables/project.py index c56be29f8..7387a1c63 100644 --- a/db/python/tables/project.py +++ b/db/python/tables/project.py @@ -1,20 +1,18 @@ # pylint: disable=global-statement -from typing import Any, Dict, Iterable, List +from typing import TYPE_CHECKING, Any -from async_lru import alru_cache from databases import Database +from typing_extensions import TypedDict from api.settings import is_all_access -from db.python.connect import Connection, SMConnections -from db.python.utils import ( - Forbidden, - NoProjectAccess, - NotFoundError, - get_logger, - to_db_json, -) -from models.models.group import GroupProjectRole -from models.models.project import Project, ProjectId +from db.python.utils import Forbidden, NotFoundError, get_logger, to_db_json +from models.models.project import Project, ProjectMemberRole + +# Avoid circular import for type definition +if TYPE_CHECKING: + from db.python.connect import Connection +else: + Connection = object logger = get_logger() @@ -22,6 +20,13 @@ GROUP_NAME_MEMBERS_ADMIN = 'members-admin' +class ProjectMemberWithRole(TypedDict): + """Dict passed to the update project member endpoint to specify roles for members""" + + member: str + role: str + + class ProjectPermissionsTable: """ Capture project operations and queries @@ -29,14 +34,6 @@ class ProjectPermissionsTable: table_name = 'project' - @staticmethod - def get_project_group_name(project_name: str, role: GroupProjectRole) -> str: - """ - Get group name for a project, for the given role - """ - - return f'{project_name}-{role.name}' - def __init__( self, connection: Connection | None, @@ -44,48 +41,17 @@ def __init__( database_connection: Database | None = None, ): self._connection = connection - if not database_connection and not connection: - raise ValueError( - 'Must call project permissions table with either a direct ' - 'database_connection or a fully formed connection' - ) - self.connection: Database = database_connection or connection.connection - self.gtable = GroupTable(self.connection, allow_full_access=allow_full_access) - - @staticmethod - async def get_project_connection( - *, - author: str, - project_name: str, - allowed_roles: set[GroupProjectRole], - ar_guid: str, - on_behalf_of: str | None = None, - meta: dict[str, str] | None = None, - ): - """Get a db connection from a project and user""" - # maybe it makes sense to perform permission checks here too - logger.debug(f'Authenticate connection to {project_name} with {author!r}') - - conn = await SMConnections.get_made_connection() - pt = ProjectPermissionsTable(connection=None, database_connection=conn) - - project = await pt.get_and_check_access_to_project_for_name( - user=author, project_name=project_name, allowed_roles=allowed_roles - ) + if not database_connection: + if not connection: + raise ValueError( + 'Must call project permissions table with 
either a direct ' + 'database_connection or a fully formed connection' + ) + self.connection = connection.connection + else: + self.connection = database_connection - # python types doesn't know this can't be none due - # to the default raise_exception of true - assert project - - return Connection( - connection=conn, - author=author, - project=project.id, - allowed_roles=allowed_roles, - on_behalf_of=on_behalf_of, - ar_guid=ar_guid, - meta=meta, - ) + self.gtable = GroupTable(self.connection, allow_full_access=allow_full_access) async def audit_log_id(self): """ @@ -97,232 +63,49 @@ async def audit_log_id(self): ) return await self._connection.audit_log_id() - # region UNPROTECTED_GETS - - @alru_cache() - async def _get_project_rows_internal(self): - """ - Internally cached get_project_rows - """ - _query = """ - SELECT id, name, meta, dataset - FROM project - """ - rows = await self.connection.fetch_all(_query) - return list(map(Project.from_db, rows)) - - async def _get_project_id_map(self): - """ - Internally cached get_project_id_map - """ - return {p.id: p for p in await self._get_project_rows_internal()} - - async def _get_project_name_map(self) -> Dict[str, int | None]: - """Get {project_name: project_id} map""" - return { - p.name: p.id - for p in await self._get_project_rows_internal() - if p.name is not None - } - - async def _get_project_by_id(self, project_id: ProjectId) -> Project: - """Get project by id""" - pmap = await self._get_project_id_map() - if project_id not in pmap: - raise NotFoundError(f'Could not find project {project_id}') - return pmap[project_id] - - async def _get_project_by_name(self, project_name: str) -> Project: - """Get project by name""" - pmap = await self._get_project_name_map() - if project_name not in pmap: - raise NotFoundError(f'Could not find project {project_name}') - return await self._get_project_by_id(pmap[project_name]) - - async def _get_projects_by_ids( - self, project_ids: Iterable[ProjectId] - ) -> List[Project]: - """Get projects by ids""" - pids = set(project_ids) - pmap = await self._get_project_id_map() - missing_pids = pids - set(pmap.keys()) - if missing_pids: - raise NotFoundError(f'Could not find projects {missing_pids}') - return [pmap[pid] for pid in pids] - - # endregion UNPROTECTED_GETS - # region AUTH - - async def get_all_projects(self, author: str): - """Get all projects""" - await self.check_project_creator_permissions(author) - return await self._get_project_rows_internal() - async def get_projects_accessible_by_user( - self, - user: str, - allowed_roles: set[GroupProjectRole], - project_id_filter: list[int] | None = None, - ) -> list[Project]: + self, user: str, return_all_projects: bool = False + ) -> tuple[dict[int, Project], dict[str, Project]]: """ Get projects that are accessible by the specified user """ - - parameters: dict[str, str | list[str] | list[int]] = { + parameters: dict[str, str] = { 'user': user, - 'allowed_roles': list(role.name for role in allowed_roles), } - project_id_filter_str = '' - if project_id_filter is not None: - parameters['project_ids'] = project_id_filter - project_id_filter_str = 'AND p.id in :project_ids' - - query = f""" - SELECT DISTINCT p.id + # In most cases we want to exclude projects that the user doesn't explicitly + # have access to. If the user is in the project creators group it may be + # necessary to return all projects whether the user has explict access to them + # or not. 
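(For readers of this hunk: the GROUP_CONCAT in the query below folds one project_member row per role into a single comma-separated `roles` column, and Project.from_db — in the models/models/project.py hunk later in this patch — splits it back into a set of ProjectMemberRole members. A minimal sketch of that round trip, using a made-up row dict rather than a real database record:)

from models.models.project import Project, ProjectMemberRole

# Hypothetical row, shaped like the query below returns it for a user who is
# both a reader and a writer on the project (GROUP_CONCAT order is not
# guaranteed, so the roles are treated as an unordered set).
row = {
    'id': 1,
    'name': 'example-project',
    'dataset': 'example-project',
    'meta': '{}',
    'roles': 'reader,writer',
}
project = Project.from_db(dict(row))
assert project.roles == {ProjectMemberRole.reader, ProjectMemberRole.writer}
# With return_all_projects=True, projects the user has no membership in come
# back with roles = NULL, which from_db turns into an empty set.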
+ where_cond = 'WHERE pm.member = :user' if return_all_projects is False else '' + + _query = f""" + SELECT + p.id, + p.name, + p.meta, + p.dataset, + GROUP_CONCAT(pm.role) as roles FROM project p - INNER JOIN project_groups pg - ON pg.project_id = p.id - AND pg.role IN :allowed_roles - INNER JOIN group_member gm ON gm.group_id = pg.group_id - WHERE gm.member = :user - {project_id_filter_str} + LEFT JOIN project_member pm + ON p.id = pm.project_id + AND pm.member = :user + {where_cond} + GROUP BY p.id """ - relevant_project_ids = await self.connection.fetch_all(query, parameters) + user_projects = await self.connection.fetch_all(_query, parameters) - projects = await self._get_projects_by_ids( - [p['id'] for p in relevant_project_ids] - ) + project_id_map: dict[int, Project] = {} + project_name_map: dict[str, Project] = {} - return projects + for row in user_projects: + project = Project.from_db(dict(row)) + project_id_map[row['id']] = project + project_name_map[row['name']] = project - async def get_and_check_access_to_project_for_id( - self, - user: str, - project_id: ProjectId, - allowed_roles: set[GroupProjectRole], - raise_exception: bool = True, - ) -> Project | None: - """Get project by id""" - project = await self._get_project_by_id(project_id) - - projects = await self.get_projects_accessible_by_user( - user, allowed_roles, [project_id] - ) - has_access = len(projects) == 1 - if not has_access: - if raise_exception: - raise NoProjectAccess( - [project.name], - allowed_roles=[r.name for r in allowed_roles], - author=user, - ) - return None - - return project - - async def get_and_check_access_to_project_for_name( - self, - user: str, - project_name: str, - allowed_roles: set[GroupProjectRole], - raise_exception: bool = True, - ) -> Project | None: - """Get project by name + perform access checks""" - project = await self._get_project_by_name(project_name) - return await self.get_and_check_access_to_project_for_id( - user, project.id, allowed_roles, raise_exception - ) - - async def get_and_check_access_to_projects_for_names( - self, user: str, project_names: list[str], allowed_roles: set[GroupProjectRole] - ): - """Get projects by names + perform access checks""" - project_name_map = await self._get_project_name_map() - - # check missing_projects - missing_project_names = set(project_names) - set(project_name_map.keys()) - if missing_project_names: - raise NotFoundError( - f'Could not find projects {', '.join(missing_project_names)}' - ) - project_ids = [project_name_map[name] for name in project_names] - # this extra filter is needed for the type sytem to be happy - # that there's no Nones in the list - filtered_project_ids = [p for p in project_ids if p is not None] - - return await self.get_and_check_access_to_projects_for_ids( - user, filtered_project_ids, allowed_roles - ) - - async def get_and_check_access_to_projects_for_ids( - self, - user: str, - project_ids: list[ProjectId], - allowed_roles: set[GroupProjectRole], - ) -> list[Project]: - """Get project by id""" - if not project_ids: - raise Forbidden( - "You don't have access to this resources, as the resource you " - "requested didn't belong to a project" - ) - - projects = await self._get_projects_by_ids(project_ids) - - # Check is any of the provided ids aren't valid project ids - missing_project_ids = set(project_ids) - set(p.id for p in projects) - missing_project_id_strs = [str(p) for p in missing_project_ids] - if missing_project_ids: - raise NotFoundError( - f'Could not find projects with ids {', 
'.join(missing_project_id_strs)}' - ) - - accessible_projects = await self.get_projects_accessible_by_user( - user, allowed_roles, project_ids - ) - - accessible_project_ids = set(p.id for p in accessible_projects) - missing_project_names = [ - p.name for p in projects if p.id not in accessible_project_ids - ] - - if missing_project_names: - raise NoProjectAccess( - missing_project_names, - allowed_roles=[r.name for r in allowed_roles], - author=user, - ) - - return accessible_projects - - async def check_access_to_project_id( - self, - user: str, - project_id: ProjectId, - allowed_roles: set[GroupProjectRole], - raise_exception: bool = True, - ) -> bool: - """Check user has access to a single project id""" - project = await self.get_and_check_access_to_project_for_id( - user, project_id, allowed_roles, raise_exception - ) - return project is not None - - async def check_access_to_project_ids( - self, - user: str, - project_ids: Iterable[ProjectId], - allowed_roles: set[GroupProjectRole], - ) -> bool: - """Check user has access to list of project_ids""" - # This will raise an exception if any of the specified project ids are missing - await self.get_and_check_access_to_projects_for_ids( - user, list(project_ids), allowed_roles - ) - return True + return project_id_map, project_name_map async def check_project_creator_permissions(self, author: str): """Check author has project_creator permissions""" @@ -336,6 +119,19 @@ async def check_project_creator_permissions(self, author: str): return True + async def check_member_admin_permissions(self, author: str): + """Check author has member_admin permissions""" + # check permissions in here + is_in_group = await self.gtable.check_if_member_in_group_name( + GROUP_NAME_MEMBERS_ADMIN, author + ) + if not is_in_group: + raise Forbidden( + f'User {author} does not have permission to edit project members' + ) + + return True + # endregion AUTH # region CREATE / UPDATE @@ -364,9 +160,6 @@ async def create_project( project_id = await self.connection.fetch_val(_query, values) - # pylint: disable=no-member - self._get_project_rows_internal.cache_invalidate() - return project_id async def update_project(self, project_name: str, update: dict, author: str): @@ -375,7 +168,7 @@ async def update_project(self, project_name: str, update: dict, author: str): meta = update.get('meta') - fields: Dict[str, Any] = { + fields: dict[str, Any] = { 'audit_log_id': await self.audit_log_id(), 'name': project_name, } @@ -392,12 +185,7 @@ async def update_project(self, project_name: str, update: dict, author: str): await self.connection.execute(_query, fields) - # pylint: disable=no-member - self._get_project_rows_internal.cache_invalidate() - - async def delete_project_data( - self, project_id: int, delete_project: bool, author: str - ) -> bool: + async def delete_project_data(self, project_id: int, delete_project: bool) -> bool: """ Delete data in metamist project, requires project_creator_permissions Can optionally delete the project also. 
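(With the table-level access-check helpers above removed, checks are meant to go through the role maps now carried on the Connection, per the db/python/connect.py hunk earlier in this patch. The layer-side updates are not part of this hunk, so the following is only a rough sketch; the function and variable names are invented for the example:)

from db.python.connect import Connection
from models.models.project import ReadAccessRoles

async def get_widgets_for_projects(connection: Connection, project_ids: list[int]):
    # Synchronous check against the role maps loaded once when the connection
    # was created: raises NotFoundError when an id is unknown (or the user has
    # no membership at all), and NoProjectAccess when they have some access but
    # none of the allowed roles. No extra database queries are issued here.
    connection.check_access_to_projects_for_ids(
        project_ids, allowed_roles=ReadAccessRoles
    )
    # ... then query the relevant tables for those projects as before.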
@@ -405,7 +193,6 @@ async def delete_project_data( if delete_project: # stop allowing delete project with analysis-runner entries raise ValueError('2024-03-08: delete_project is no longer allowed') - await self.check_project_creator_permissions(author) async with self.connection.transaction(): _query = """ @@ -446,62 +233,25 @@ async def delete_project_data( DELETE FROM participant WHERE project = :project; DELETE FROM analysis WHERE project = :project; """ - values: dict = {'project': project_id} - if delete_project: - group_ids_rows = await self.connection.fetch_all( - """ - SELECT group_id - FROM project_groups WHERE project_id = :project' - """ - ) - group_ids = set(r['group_id'] for r in group_ids_rows) - - _query += 'DELETE FROM project WHERE id = :project;\n' - if len(group_ids) > 0: - _query += 'DELETE FROM `group` WHERE id IN :group_ids\n' - _query += ( - 'DELETE FROM `project_groups` WHERE project_id = :project\n' - ) - values['group_ids'] = list(group_ids) await self.connection.execute(_query, {'project': project_id}) - if delete_project: - # pylint: disable=no-member - self._get_project_rows_internal.cache_invalidate() - return True - async def set_group_members(self, group_name: str, members: list[str], author: str): + async def set_project_members( + self, project: Project, members: list[ProjectMemberRole] + ): """ Set group members for a group (by name) """ - - has_permission = await self.gtable.check_if_member_in_group_name( - GROUP_NAME_MEMBERS_ADMIN, author - ) - if not has_permission: - raise Forbidden( - f'User {author} does not have permission to add members to group {group_name}' - ) - group_id = await self.gtable.get_group_name_from_id(group_name) - await self.gtable.set_group_members( - group_id, members, audit_log_id=await self.audit_log_id() - ) + print('@TODO') + # group_id = await self.gtable.get_group_name_from_id(group_name) + # await self.gtable.set_group_members( + # group_id, members, audit_log_id=await self.audit_log_id() + # ) # endregion CREATE / UPDATE - async def get_seqr_projects(self) -> list[dict[str, Any]]: - """ - Get all projects with meta.is_seqr = true - """ - - all_projects = await self._get_project_rows_internal() - seqr_projects = [p for p in all_projects if p.meta.get('is_seqr')] - return seqr_projects - - # gruo - class GroupTable: """ diff --git a/db/python/utils.py b/db/python/utils.py index 3b73660c4..4cb1def73 100644 --- a/db/python/utils.py +++ b/db/python/utils.py @@ -3,7 +3,8 @@ import logging import os import re -from typing import Any, Generic, TypeVar +from enum import Enum +from typing import Any, Generic, Sequence, TypeVar from models.base import SMBase @@ -71,10 +72,10 @@ def __init__( project_names_str = ( ', '.join(repr(p) for p in project_names) if project_names else '' ) - allowed_roles_str = ' or '.join([r.name for r in allowed_roles]) + required_roles_str = ' or '.join(allowed_roles) super().__init__( - f'{author} does not have {allowed_roles_str} access to resources from the ' + f'{author} does not have {required_roles_str} access to resources from the ' f'following project(s), or they may not exist: {project_names_str}', *args, ) diff --git a/models/models/group.py b/models/models/group.py deleted file mode 100644 index 603404d72..000000000 --- a/models/models/group.py +++ /dev/null @@ -1,29 +0,0 @@ -from enum import Enum - -from models.base import SMBase - -GroupProjectRole = Enum('GroupProjectRole', ['read', 'contribute', 'write']) - -# These roles have read access to a project -ReadAccessRoles = { - 
GroupProjectRole.read, - GroupProjectRole.contribute, - GroupProjectRole.write, -} - -# Only write has full write access -FullWriteAccessRoles = {GroupProjectRole.write} - - -class Group(SMBase): - """Row for project in 'project' table""" - - id: int - name: str - role: GroupProjectRole - - @staticmethod - def from_db(kwargs): - """From DB row, with db keys""" - kwargs = dict(kwargs) - return Group(**kwargs) diff --git a/models/models/project.py b/models/models/project.py index 430c66253..53640dd41 100644 --- a/models/models/project.py +++ b/models/models/project.py @@ -1,22 +1,63 @@ import json -from typing import Optional +from enum import Enum +from typing import Any, Optional + +from pydantic import field_serializer from models.base import SMBase ProjectId = int +ProjectMemberRole = Enum( + 'ProjectMemberRole', ['reader', 'contributor', 'writer', 'data_manager'] +) + +project_member_role_names = [r.name for r in ProjectMemberRole] + +# These roles have read access to a project +ReadAccessRoles = { + ProjectMemberRole.reader, + ProjectMemberRole.contributor, + ProjectMemberRole.writer, + ProjectMemberRole.data_manager, +} + +# Only write has full write access +FullWriteAccessRoles = {ProjectMemberRole.writer} + class Project(SMBase): """Row for project in 'project' table""" id: ProjectId name: str - dataset: Optional[str] = None - meta: Optional[dict] = None + dataset: str + meta: Optional[dict[str, Any]] = None + roles: set[ProjectMemberRole] + """The roles that the current user has within the project""" + + @property + def is_test(self): + """ + Checks whether this is a test project. Comparing to the dataset is safer than + just checking whether the name ends with -test, just in case we have a non-test + project that happens to end with -test + """ + return self.name == f'{self.dataset}-test' + + @field_serializer("roles") + def serialize_roles(self, roles: set[ProjectMemberRole], _info): + return [r.name for r in roles] @staticmethod def from_db(kwargs): """From DB row, with db keys""" kwargs = dict(kwargs) kwargs['meta'] = json.loads(kwargs['meta']) if kwargs.get('meta') else {} + + # Sanitise role names and convert to enum members + role_list: list[str] = kwargs['roles'].split(',') if kwargs.get('roles') else [] + kwargs['roles'] = { + ProjectMemberRole[r] for r in role_list if r in project_member_role_names + } return Project(**kwargs) From e859d079185255e2813b0ecdf330aa0504cd59cd Mon Sep 17 00:00:00 2001 From: Dan Coates Date: Fri, 14 Jun 2024 16:52:57 +1000 Subject: [PATCH 09/29] Update graphql loaders and schema to work with new permissions Also some QOL fixes for graphql types so that the context is now properly typed --- api/graphql/loaders.py | 86 +++++++------- api/graphql/schema.py | 248 +++++++++++++++++++++-------------------- 2 files changed, 174 insertions(+), 160 deletions(-) diff --git a/api/graphql/loaders.py b/api/graphql/loaders.py index 87889ae84..eef71ebf9 100644 --- a/api/graphql/loaders.py +++ b/api/graphql/loaders.py @@ -4,13 +4,14 @@ import dataclasses import enum from collections import defaultdict -from typing import Any +from typing import Any, TypedDict from fastapi import Request from strawberry.dataloader import DataLoader from api.utils import group_by from api.utils.db import get_projectless_db_connection +from db.python.connect import Connection from db.python.layers import ( AnalysisLayer, AssayLayer, @@ -23,7 +24,6 @@ from db.python.tables.analysis import AnalysisFilter from db.python.tables.assay import AssayFilter from db.python.tables.family 
import FamilyFilter -from db.python.tables.project import ProjectPermissionsTable from db.python.tables.sample import SampleFilter from db.python.tables.sequencing_group import SequencingGroupFilter from db.python.utils import GenericFilter, NotFoundError, get_hashable_value @@ -39,7 +39,6 @@ ) from models.models.audit_log import AuditLogInternal from models.models.family import PedRowInternal -from models.models.group import ReadAccessRoles class LoaderKeys(enum.Enum): @@ -79,14 +78,14 @@ class LoaderKeys(enum.Enum): SEQUENCING_GROUPS_FOR_ANALYSIS = 'sequencing_groups_for_analysis' -loaders = {} +loaders: dict[LoaderKeys, Any] = {} -def connected_data_loader(id_: LoaderKeys, cache=True): +def connected_data_loader(id_: LoaderKeys, cache: bool = True): """Provide connection to a data loader""" def connected_data_loader_caller(fn): - def inner(connection): + def inner(connection: Connection): async def wrapped(*args, **kwargs): return await fn(*args, **kwargs, connection=connection) @@ -110,7 +109,7 @@ def connected_data_loader_with_params( """ def connected_data_loader_caller(fn): - def inner(connection): + def inner(connection: Connection): async def wrapped(query: list[dict[str, Any]]) -> list[Any]: by_key: dict[tuple, Any] = {} @@ -159,7 +158,7 @@ async def wrapped(query: list[dict[str, Any]]) -> list[Any]: @connected_data_loader(LoaderKeys.AUDIT_LOGS_BY_IDS) async def load_audit_logs_by_ids( - audit_log_ids: list[int], connection + audit_log_ids: list[int], connection: Connection ) -> list[AuditLogInternal | None]: """ DataLoader: get_audit_logs_by_ids @@ -172,7 +171,7 @@ async def load_audit_logs_by_ids( @connected_data_loader(LoaderKeys.AUDIT_LOGS_BY_ANALYSIS_IDS) async def load_audit_logs_by_analysis_ids( - analysis_ids: list[int], connection + analysis_ids: list[int], connection: Connection ) -> list[list[AuditLogInternal]]: """ DataLoader: get_audit_logs_by_analysis_ids @@ -184,7 +183,7 @@ async def load_audit_logs_by_analysis_ids( @connected_data_loader_with_params(LoaderKeys.ASSAYS_FOR_SAMPLES, default_factory=list) async def load_assays_by_samples( - connection, ids, filter: AssayFilter + connection: Connection, ids, filter: AssayFilter ) -> dict[int, list[AssayInternal]]: """ DataLoader: get_assays_for_sample_ids @@ -200,7 +199,7 @@ async def load_assays_by_samples( @connected_data_loader(LoaderKeys.ASSAYS_FOR_SEQUENCING_GROUPS) async def load_assays_by_sequencing_groups( - sequencing_group_ids: list[int], connection + sequencing_group_ids: list[int], connection: Connection ) -> list[list[AssayInternal]]: """ Get all assays belong to the sequencing groups @@ -209,7 +208,7 @@ async def load_assays_by_sequencing_groups( # group by all last fields, in case we add more assays = await assaylayer.get_assays_for_sequencing_group_ids( - sequencing_group_ids=sequencing_group_ids, check_project_ids=False + sequencing_group_ids=sequencing_group_ids ) return [assays.get(sg, []) for sg in sequencing_group_ids] @@ -219,7 +218,7 @@ async def load_assays_by_sequencing_groups( LoaderKeys.SAMPLES_FOR_PARTICIPANTS, default_factory=list ) async def load_samples_for_participant_ids( - ids: list[int], filter: SampleFilter, connection + ids: list[int], filter: SampleFilter, connection: Connection ) -> dict[int, list[SampleInternal]]: """ DataLoader: get_samples_for_participant_ids @@ -232,7 +231,7 @@ async def load_samples_for_participant_ids( @connected_data_loader(LoaderKeys.SEQUENCING_GROUPS_FOR_IDS) async def load_sequencing_groups_for_ids( - sequencing_group_ids: list[int], connection + 
sequencing_group_ids: list[int], connection: Connection ) -> list[SequencingGroupInternal]: """ DataLoader: get_sequencing_groups_by_ids @@ -249,7 +248,7 @@ async def load_sequencing_groups_for_ids( LoaderKeys.SEQUENCING_GROUPS_FOR_SAMPLES, default_factory=list ) async def load_sequencing_groups_for_samples( - connection, ids: list[int], filter: SequencingGroupFilter + connection: Connection, ids: list[int], filter: SequencingGroupFilter ) -> dict[int, list[SequencingGroupInternal]]: """ Has format [(sample_id: int, sequencing_type?: string)] @@ -265,7 +264,7 @@ async def load_sequencing_groups_for_samples( @connected_data_loader(LoaderKeys.SAMPLES_FOR_IDS) async def load_samples_for_ids( - sample_ids: list[int], connection + sample_ids: list[int], connection: Connection ) -> list[SampleInternal]: """ DataLoader: get_samples_for_ids @@ -281,7 +280,7 @@ async def load_samples_for_ids( LoaderKeys.SAMPLES_FOR_PROJECTS, default_factory=list ) async def load_samples_for_projects( - connection, ids: list[ProjectId], filter: SampleFilter + connection: Connection, ids: list[ProjectId], filter: SampleFilter ): """ DataLoader: get_samples_for_project_ids @@ -295,7 +294,7 @@ async def load_samples_for_projects( @connected_data_loader(LoaderKeys.PARTICIPANTS_FOR_IDS) async def load_participants_for_ids( - participant_ids: list[int], connection + participant_ids: list[int], connection: Connection ) -> list[ParticipantInternal]: """ DataLoader: get_participants_by_ids @@ -313,7 +312,7 @@ async def load_participants_for_ids( @connected_data_loader(LoaderKeys.SEQUENCING_GROUPS_FOR_ANALYSIS) async def load_sequencing_groups_for_analysis_ids( - analysis_ids: list[int], connection + analysis_ids: list[int], connection: Connection ) -> list[list[SequencingGroupInternal]]: """ DataLoader: get_samples_for_analysis_ids @@ -328,7 +327,7 @@ async def load_sequencing_groups_for_analysis_ids( LoaderKeys.SEQUENCING_GROUPS_FOR_PROJECTS, default_factory=list ) async def load_sequencing_groups_for_project_ids( - ids: list[int], filter: SequencingGroupFilter, connection + ids: list[int], filter: SequencingGroupFilter, connection: Connection ) -> dict[int, list[SequencingGroupInternal]]: """ DataLoader: get_sequencing_groups_for_project_ids @@ -342,39 +341,33 @@ async def load_sequencing_groups_for_project_ids( @connected_data_loader(LoaderKeys.PROJECTS_FOR_IDS) -async def load_projects_for_ids(project_ids: list[int], connection) -> list[Project]: +async def load_projects_for_ids( + project_ids: list[int], connection: Connection +) -> list[Project]: """ Get projects by IDs """ - pttable = ProjectPermissionsTable(connection) - projects = await pttable.get_and_check_access_to_projects_for_ids( - user=connection.user, project_ids=project_ids, allowed_roles=ReadAccessRoles - ) - - p_by_id = {p.id: p for p in projects} - projects = [p_by_id.get(p) for p in project_ids] + projects = [connection.project_id_map.get(p) for p in project_ids] return [p for p in projects if p is not None] @connected_data_loader(LoaderKeys.FAMILIES_FOR_PARTICIPANTS) async def load_families_for_participants( - participant_ids: list[int], connection + participant_ids: list[int], connection: Connection ) -> list[list[FamilyInternal]]: """ Get families of participants, noting a participant can be in multiple families """ flayer = FamilyLayer(connection) - fam_map = await flayer.get_families_by_participants( - participant_ids=participant_ids, check_project_ids=False - ) + fam_map = await flayer.get_families_by_participants(participant_ids=participant_ids) 
return [fam_map.get(p, []) for p in participant_ids] @connected_data_loader(LoaderKeys.PARTICIPANTS_FOR_FAMILIES) async def load_participants_for_families( - family_ids: list[int], connection + family_ids: list[int], connection: Connection ) -> list[list[ParticipantInternal]]: """Get all participants in a family, doesn't include affected statuses""" player = ParticipantLayer(connection) @@ -384,7 +377,7 @@ async def load_participants_for_families( @connected_data_loader(LoaderKeys.PARTICIPANTS_FOR_PROJECTS) async def load_participants_for_projects( - project_ids: list[ProjectId], connection + project_ids: list[ProjectId], connection: Connection ) -> list[list[ParticipantInternal]]: """ Get all participants in a project @@ -406,7 +399,7 @@ async def load_participants_for_projects( async def load_analyses_for_sequencing_groups( ids: list[int], filter_: AnalysisFilter, - connection, + connection: Connection, ) -> dict[int, list[AnalysisInternal]]: """ Type: (sequencing_group_id: int, status?: AnalysisStatus, type?: str) @@ -424,7 +417,7 @@ async def load_analyses_for_sequencing_groups( @connected_data_loader(LoaderKeys.PHENOTYPES_FOR_PARTICIPANTS) async def load_phenotypes_for_participants( - participant_ids: list[int], connection + participant_ids: list[int], connection: Connection ) -> list[dict]: """ Data loader for phenotypes for participants @@ -438,7 +431,7 @@ async def load_phenotypes_for_participants( @connected_data_loader(LoaderKeys.FAMILIES_FOR_IDS) async def load_families_for_ids( - family_ids: list[int], connection + family_ids: list[int], connection: Connection ) -> list[FamilyInternal]: """ DataLoader: get_families_for_ids @@ -451,7 +444,7 @@ async def load_families_for_ids( @connected_data_loader(LoaderKeys.FAMILY_PARTICIPANTS_FOR_FAMILIES) async def load_family_participants_for_families( - family_ids: list[int], connection + family_ids: list[int], connection: Connection ) -> list[list[PedRowInternal]]: """ DataLoader: get_family_participants_for_families @@ -464,7 +457,7 @@ async def load_family_participants_for_families( @connected_data_loader(LoaderKeys.FAMILY_PARTICIPANTS_FOR_PARTICIPANTS) async def load_family_participants_for_participants( - participant_ids: list[int], connection + participant_ids: list[int], connection: Connection ) -> list[list[PedRowInternal]]: """data loader for family participants for participants @@ -485,12 +478,21 @@ async def load_family_participants_for_participants( return [fp_map.get(pid, []) for pid in participant_ids] +class GraphQLContext(TypedDict): + """Basic dict type for GraphQL context to be passed to resolvers""" + + loaders: dict[LoaderKeys, Any] + connection: Connection + + async def get_context( - request: Request, connection=get_projectless_db_connection -): # pylint: disable=unused-argument + request: Request, # pylint: disable=unused-argument + connection: Connection = get_projectless_db_connection, +) -> GraphQLContext: """Get loaders / cache context for strawberyy GraphQL""" mapped_loaders = {k: fn(connection) for k, fn in loaders.items()} + return { 'connection': connection, - **mapped_loaders, + 'loaders': mapped_loaders, } diff --git a/api/graphql/schema.py b/api/graphql/schema.py index 259481bfa..42edd8e72 100644 --- a/api/graphql/schema.py +++ b/api/graphql/schema.py @@ -18,7 +18,7 @@ GraphQLMetaFilter, graphql_meta_filter_to_internal_filter, ) -from api.graphql.loaders import LoaderKeys, get_context +from api.graphql.loaders import GraphQLContext, LoaderKeys, get_context from db.python import enum_tables from 
db.python.layers import ( AnalysisLayer, @@ -34,7 +34,6 @@ from db.python.tables.assay import AssayFilter from db.python.tables.cohort import CohortFilter, CohortTemplateFilter from db.python.tables.family import FamilyFilter -from db.python.tables.project import ProjectPermissionsTable from db.python.tables.sample import SampleFilter from db.python.tables.sequencing_group import SequencingGroupFilter from db.python.utils import GenericFilter @@ -53,8 +52,7 @@ ) from models.models.analysis_runner import AnalysisRunnerInternal from models.models.family import PedRowInternal -from models.models.group import ReadAccessRoles -from models.models.project import ProjectId +from models.models.project import ProjectId, ReadAccessRoles from models.models.sample import sample_id_transform_to_raw from models.utils.cohort_id_format import cohort_id_format, cohort_id_transform_to_raw from models.utils.cohort_template_id_format import ( @@ -73,7 +71,7 @@ continue def create_function(_enum): - async def m(info: Info) -> list[str]: + async def m(info: Info[GraphQLContext, 'Query']) -> list[str]: return await _enum(info.context['connection']).get() m.__name__ = _enum.get_enum_name() @@ -114,20 +112,18 @@ def from_internal(internal: CohortInternal) -> 'GraphQLCohort': @strawberry.field() async def template( - self, info: Info, root: 'GraphQLCohort' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLCohort' ) -> 'GraphQLCohortTemplate': connection = info.context['connection'] template = await CohortLayer(connection).get_template_by_cohort_id( cohort_id_transform_to_raw(root.id) ) - ptable = ProjectPermissionsTable(connection) - projects = await ptable.get_and_check_access_to_projects_for_ids( - user=connection.author, + projects = connection.get_and_check_access_to_projects_for_ids( project_ids=( template.criteria.projects if template.criteria.projects else [] ), - readonly=True, + allowed_roles=ReadAccessRoles, ) project_names = [p.name for p in projects if p.name] @@ -137,7 +133,7 @@ async def template( @strawberry.field() async def sequencing_groups( - self, info: Info, root: 'GraphQLCohort' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLCohort' ) -> list['GraphQLSequencingGroup']: connection = info.context['connection'] cohort_layer = CohortLayer(connection) @@ -151,10 +147,9 @@ async def sequencing_groups( @strawberry.field() async def analyses( - self, info: Info, root: 'GraphQLCohort' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLCohort' ) -> list['GraphQLAnalysis']: connection = info.context['connection'] - connection.project = root.project_id internal_analysis = await AnalysisLayer(connection).query( AnalysisFilter( cohort_id=GenericFilter(in_=[cohort_id_transform_to_raw(root.id)]), @@ -163,8 +158,10 @@ async def analyses( return [GraphQLAnalysis.from_internal(a) for a in internal_analysis] @strawberry.field() - async def project(self, info: Info, root: 'GraphQLCohort') -> 'GraphQLProject': - loader = info.context[LoaderKeys.PROJECTS_FOR_IDS] + async def project( + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLCohort' + ) -> 'GraphQLProject': + loader = info.context['loaders'][LoaderKeys.PROJECTS_FOR_IDS] project = await loader.load(root.project_id) return GraphQLProject.from_internal(project) @@ -196,9 +193,9 @@ def from_internal( @strawberry.field() async def project( - self, info: Info, root: 'GraphQLCohortTemplate' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLCohortTemplate' ) -> 'GraphQLProject': - loader = 
info.context[LoaderKeys.PROJECTS_FOR_IDS] + loader = info.context['loaders'][LoaderKeys.PROJECTS_FOR_IDS] project = await loader.load(root.project_id) return GraphQLProject.from_internal(project) @@ -224,7 +221,7 @@ def from_internal(internal: Project) -> 'GraphQLProject': @strawberry.field() async def analysis_runner( self, - info: Info, + info: Info[GraphQLContext, 'Query'], root: 'Project', ar_guid: GraphQLFilter[str] | None = None, author: GraphQLFilter[str] | None = None, @@ -248,7 +245,7 @@ async def analysis_runner( @strawberry.field() async def pedigree( self, - info: Info, + info: Info[GraphQLContext, 'Query'], root: 'Project', internal_family_ids: list[int] | None = None, replace_with_participant_external_ids: bool = True, @@ -276,7 +273,7 @@ async def pedigree( @strawberry.field() async def families( self, - info: Info, + info: Info[GraphQLContext, 'Query'], root: 'GraphQLProject', id: GraphQLFilter[int] | None = None, external_id: GraphQLFilter[str] | None = None, @@ -295,23 +292,23 @@ async def families( @strawberry.field() async def participants( - self, info: Info, root: 'Project' + self, info: Info[GraphQLContext, 'Query'], root: 'Project' ) -> list['GraphQLParticipant']: - loader = info.context[LoaderKeys.PARTICIPANTS_FOR_PROJECTS] + loader = info.context['loaders'][LoaderKeys.PARTICIPANTS_FOR_PROJECTS] participants = await loader.load(root.id) return [GraphQLParticipant.from_internal(p) for p in participants] @strawberry.field() async def samples( self, - info: Info, + info: Info[GraphQLContext, 'Query'], root: 'GraphQLProject', type: GraphQLFilter[str] | None = None, external_id: GraphQLFilter[str] | None = None, id: GraphQLFilter[str] | None = None, meta: GraphQLMetaFilter | None = None, ) -> list['GraphQLSample']: - loader = info.context[LoaderKeys.SAMPLES_FOR_PROJECTS] + loader = info.context['loaders'][LoaderKeys.SAMPLES_FOR_PROJECTS] filter_ = SampleFilter( type=type.to_internal_filter() if type else None, external_id=external_id.to_internal_filter() if external_id else None, @@ -324,7 +321,7 @@ async def samples( @strawberry.field() async def sequencing_groups( self, - info: Info, + info: Info[GraphQLContext, 'Query'], root: 'GraphQLProject', id: GraphQLFilter[str] | None = None, external_id: GraphQLFilter[str] | None = None, @@ -333,7 +330,7 @@ async def sequencing_groups( platform: GraphQLFilter[str] | None = None, active_only: GraphQLFilter[bool] | None = None, ) -> list['GraphQLSequencingGroup']: - loader = info.context[LoaderKeys.SEQUENCING_GROUPS_FOR_PROJECTS] + loader = info.context['loaders'][LoaderKeys.SEQUENCING_GROUPS_FOR_PROJECTS] filter_ = SequencingGroupFilter( id=( id.to_internal_filter_mapped(sequencing_group_id_transform_to_raw) @@ -356,7 +353,7 @@ async def sequencing_groups( @strawberry.field() async def analyses( self, - info: Info, + info: Info[GraphQLContext, 'Query'], root: 'Project', type: GraphQLFilter[str] | None = None, status: GraphQLFilter[GraphQLAnalysisStatus] | None = None, @@ -366,7 +363,6 @@ async def analyses( ids: GraphQLFilter[int] | None = None, ) -> list['GraphQLAnalysis']: connection = info.context['connection'] - connection.project = root.id internal_analysis = await AnalysisLayer(connection).query( AnalysisFilter( id=ids.to_internal_filter() if ids else None, @@ -391,7 +387,7 @@ async def analyses( @strawberry.field() async def cohorts( self, - info: Info, + info: Info[GraphQLContext, 'Query'], root: 'Project', id: GraphQLFilter[str] | None = None, name: GraphQLFilter[str] | None = None, @@ -400,7 +396,6 @@ async def 
cohorts( timestamp: GraphQLFilter[datetime.datetime] | None = None, ) -> list['GraphQLCohort']: connection = info.context['connection'] - connection.project = root.id c_filter = CohortFilter( id=id.to_internal_filter_mapped(cohort_id_transform_to_raw) if id else None, @@ -472,23 +467,25 @@ def from_internal(internal: AnalysisInternal) -> 'GraphQLAnalysis': @strawberry.field async def sequencing_groups( - self, info: Info, root: 'GraphQLAnalysis' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLAnalysis' ) -> list['GraphQLSequencingGroup']: - loader = info.context[LoaderKeys.SEQUENCING_GROUPS_FOR_ANALYSIS] + loader = info.context['loaders'][LoaderKeys.SEQUENCING_GROUPS_FOR_ANALYSIS] sgs = await loader.load(root.id) return [GraphQLSequencingGroup.from_internal(sg) for sg in sgs] @strawberry.field - async def project(self, info: Info, root: 'GraphQLAnalysis') -> GraphQLProject: - loader = info.context[LoaderKeys.PROJECTS_FOR_IDS] + async def project( + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLAnalysis' + ) -> GraphQLProject: + loader = info.context['loaders'][LoaderKeys.PROJECTS_FOR_IDS] project = await loader.load(root.project) return GraphQLProject.from_internal(project) @strawberry.field async def audit_logs( - self, info: Info, root: 'GraphQLAnalysis' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLAnalysis' ) -> list[GraphQLAuditLog]: - loader = info.context[LoaderKeys.AUDIT_LOGS_BY_ANALYSIS_IDS] + loader = info.context['loaders'][LoaderKeys.AUDIT_LOGS_BY_ANALYSIS_IDS] audit_logs = await loader.load(root.id) return [GraphQLAuditLog.from_internal(audit_log) for audit_log in audit_logs] @@ -517,25 +514,27 @@ def from_internal(internal: FamilyInternal) -> 'GraphQLFamily': ) @strawberry.field - async def project(self, info: Info, root: 'GraphQLFamily') -> GraphQLProject: - loader = info.context[LoaderKeys.PROJECTS_FOR_IDS] + async def project( + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLFamily' + ) -> GraphQLProject: + loader = info.context['loaders'][LoaderKeys.PROJECTS_FOR_IDS] project = await loader.load(root.project_id) return GraphQLProject.from_internal(project) @strawberry.field async def participants( - self, info: Info, root: 'GraphQLFamily' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLFamily' ) -> list['GraphQLParticipant']: - participants = await info.context[LoaderKeys.PARTICIPANTS_FOR_FAMILIES].load( - root.id - ) + participants = await info.context['loaders'][ + LoaderKeys.PARTICIPANTS_FOR_FAMILIES + ].load(root.id) return [GraphQLParticipant.from_internal(p) for p in participants] @strawberry.field async def family_participants( - self, info: Info, root: 'GraphQLFamily' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLFamily' ) -> list['GraphQLFamilyParticipant']: - family_participants = await info.context[ + family_participants = await info.context['loaders'][ LoaderKeys.FAMILY_PARTICIPANTS_FOR_FAMILIES ].load(root.id) return [ @@ -558,17 +557,17 @@ class GraphQLFamilyParticipant: @strawberry.field async def participant( - self, info: Info, root: 'GraphQLFamilyParticipant' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLFamilyParticipant' ) -> 'GraphQLParticipant': - loader = info.context[LoaderKeys.PARTICIPANTS_FOR_IDS] + loader = info.context['loaders'][LoaderKeys.PARTICIPANTS_FOR_IDS] participant = await loader.load(root.participant_id) return GraphQLParticipant.from_internal(participant) @strawberry.field async def family( - self, info: Info, root: 'GraphQLFamilyParticipant' + self, 
info: Info[GraphQLContext, 'Query'], root: 'GraphQLFamilyParticipant' ) -> GraphQLFamily: - loader = info.context[LoaderKeys.FAMILIES_FOR_IDS] + loader = info.context['loaders'][LoaderKeys.FAMILIES_FOR_IDS] family = await loader.load(root.family_id) return GraphQLFamily.from_internal(family) @@ -613,7 +612,7 @@ def from_internal(internal: ParticipantInternal) -> 'GraphQLParticipant': @strawberry.field async def samples( self, - info: Info, + info: Info[GraphQLContext, 'Query'], root: 'GraphQLParticipant', type: GraphQLFilter[str] | None = None, meta: GraphQLMetaFilter | None = None, @@ -626,28 +625,32 @@ async def samples( ) q = {'id': root.id, 'filter': filter_} - samples = await info.context[LoaderKeys.SAMPLES_FOR_PARTICIPANTS].load(q) + samples = await info.context['loaders'][ + LoaderKeys.SAMPLES_FOR_PARTICIPANTS + ].load(q) return [GraphQLSample.from_internal(s) for s in samples] @strawberry.field async def phenotypes( - self, info: Info, root: 'GraphQLParticipant' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLParticipant' ) -> strawberry.scalars.JSON: - loader = info.context[LoaderKeys.PHENOTYPES_FOR_PARTICIPANTS] + loader = info.context['loaders'][LoaderKeys.PHENOTYPES_FOR_PARTICIPANTS] return await loader.load(root.id) @strawberry.field async def families( - self, info: Info, root: 'GraphQLParticipant' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLParticipant' ) -> list[GraphQLFamily]: - fams = await info.context[LoaderKeys.FAMILIES_FOR_PARTICIPANTS].load(root.id) + fams = await info.context['loaders'][LoaderKeys.FAMILIES_FOR_PARTICIPANTS].load( + root.id + ) return [GraphQLFamily.from_internal(f) for f in fams] @strawberry.field async def family_participants( - self, info: Info, root: 'GraphQLParticipant' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLParticipant' ) -> list[GraphQLFamilyParticipant]: - family_participants = await info.context[ + family_participants = await info.context['loaders'][ LoaderKeys.FAMILY_PARTICIPANTS_FOR_PARTICIPANTS ].load(root.id) return [ @@ -655,18 +658,20 @@ async def family_participants( ] @strawberry.field - async def project(self, info: Info, root: 'GraphQLParticipant') -> GraphQLProject: - loader = info.context[LoaderKeys.PROJECTS_FOR_IDS] + async def project( + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLParticipant' + ) -> GraphQLProject: + loader = info.context['loaders'][LoaderKeys.PROJECTS_FOR_IDS] project = await loader.load(root.project_id) return GraphQLProject.from_internal(project) @strawberry.field async def audit_log( - self, info: Info, root: 'GraphQLParticipant' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLParticipant' ) -> GraphQLAuditLog | None: if root.audit_log_id is None: return None - loader = info.context[LoaderKeys.AUDIT_LOGS_BY_IDS] + loader = info.context['loaders'][LoaderKeys.AUDIT_LOGS_BY_IDS] audit_log = await loader.load(root.audit_log_id) return GraphQLAuditLog.from_internal(audit_log) @@ -702,23 +707,27 @@ def from_internal(sample: SampleInternal) -> 'GraphQLSample': @strawberry.field async def participant( - self, info: Info, root: 'GraphQLSample' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLSample' ) -> GraphQLParticipant | None: if root.participant_id is None: return None - loader_participants_for_ids = info.context[LoaderKeys.PARTICIPANTS_FOR_IDS] + loader_participants_for_ids = info.context['loaders'][ + LoaderKeys.PARTICIPANTS_FOR_IDS + ] participant = await loader_participants_for_ids.load(root.participant_id) return 
GraphQLParticipant.from_internal(participant) @strawberry.field async def assays( self, - info: Info, + info: Info[GraphQLContext, 'Query'], root: 'GraphQLSample', type: GraphQLFilter[str] | None = None, meta: GraphQLMetaFilter | None = None, ) -> list['GraphQLAssay']: - loader_assays_for_sample_ids = info.context[LoaderKeys.ASSAYS_FOR_SAMPLES] + loader_assays_for_sample_ids = info.context['loaders'][ + LoaderKeys.ASSAYS_FOR_SAMPLES + ] filter_ = AssayFilter( type=type.to_internal_filter() if type else None, meta=meta, @@ -729,14 +738,18 @@ async def assays( return [GraphQLAssay.from_internal(assay) for assay in assays] @strawberry.field - async def project(self, info: Info, root: 'GraphQLSample') -> GraphQLProject: - project = await info.context[LoaderKeys.PROJECTS_FOR_IDS].load(root.project_id) + async def project( + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLSample' + ) -> GraphQLProject: + project = await info.context['loaders'][LoaderKeys.PROJECTS_FOR_IDS].load( + root.project_id + ) return GraphQLProject.from_internal(project) @strawberry.field async def sequencing_groups( self, - info: Info, + info: Info[GraphQLContext, 'Query'], root: 'GraphQLSample', id: GraphQLFilter[str] | None = None, type: GraphQLFilter[str] | None = None, @@ -745,7 +758,7 @@ async def sequencing_groups( meta: GraphQLMetaFilter | None = None, active_only: GraphQLFilter[bool] | None = None, ) -> list['GraphQLSequencingGroup']: - loader = info.context[LoaderKeys.SEQUENCING_GROUPS_FOR_SAMPLES] + loader = info.context['loaders'][LoaderKeys.SEQUENCING_GROUPS_FOR_SAMPLES] _filter = SequencingGroupFilter( id=( @@ -801,15 +814,17 @@ def from_internal(internal: SequencingGroupInternal) -> 'GraphQLSequencingGroup' ) @strawberry.field - async def sample(self, info: Info, root: 'GraphQLSequencingGroup') -> GraphQLSample: - loader = info.context[LoaderKeys.SAMPLES_FOR_IDS] + async def sample( + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLSequencingGroup' + ) -> GraphQLSample: + loader = info.context['loaders'][LoaderKeys.SAMPLES_FOR_IDS] sample = await loader.load(root.sample_id) return GraphQLSample.from_internal(sample) @strawberry.field async def analyses( self, - info: Info, + info: Info[GraphQLContext, 'Query'], root: 'GraphQLSequencingGroup', status: GraphQLFilter[GraphQLAnalysisStatus] | None = None, type: GraphQLFilter[str] | None = None, @@ -818,15 +833,13 @@ async def analyses( project: GraphQLFilter[str] | None = None, ) -> list[GraphQLAnalysis]: connection = info.context['connection'] - loader = info.context[LoaderKeys.ANALYSES_FOR_SEQUENCING_GROUPS] + loader = info.context['loaders'][LoaderKeys.ANALYSES_FOR_SEQUENCING_GROUPS] _project_filter: GenericFilter[ProjectId] | None = None if project: - ptable = ProjectPermissionsTable(connection) - project_ids = project.all_values() - projects = await ptable.get_and_check_access_to_projects_for_names( - user=connection.author, - project_names=project_ids, + project_names = project.all_values() + projects = connection.get_and_check_access_to_projects_for_names( + project_names=project_names, allowed_roles=ReadAccessRoles, ) project_id_map: dict[str, int] = { @@ -856,9 +869,9 @@ async def analyses( @strawberry.field async def assays( - self, info: Info, root: 'GraphQLSequencingGroup' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLSequencingGroup' ) -> list['GraphQLAssay']: - loader = info.context[LoaderKeys.ASSAYS_FOR_SEQUENCING_GROUPS] + loader = info.context['loaders'][LoaderKeys.ASSAYS_FOR_SEQUENCING_GROUPS] assays = await 
loader.load(root.internal_id) return [GraphQLAssay.from_internal(assay) for assay in assays] @@ -889,8 +902,10 @@ def from_internal(internal: AssayInternal) -> 'GraphQLAssay': ) @strawberry.field - async def sample(self, info: Info, root: 'GraphQLAssay') -> GraphQLSample: - loader = info.context[LoaderKeys.SAMPLES_FOR_IDS] + async def sample( + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLAssay' + ) -> GraphQLSample: + loader = info.context['loaders'][LoaderKeys.SAMPLES_FOR_IDS] sample = await loader.load(root.sample_id) return GraphQLSample.from_internal(sample) @@ -944,9 +959,9 @@ def from_internal(internal: AnalysisRunnerInternal) -> 'GraphQLAnalysisRunner': @strawberry.field async def project( - self, info: Info, root: 'GraphQLAnalysisRunner' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLAnalysisRunner' ) -> GraphQLProject: - loader = info.context[LoaderKeys.PROJECTS_FOR_IDS] + loader = info.context['loaders'][LoaderKeys.PROJECTS_FOR_IDS] project = await loader.load(root.internal_project) return GraphQLProject.from_internal(project) @@ -956,26 +971,25 @@ class Query: # entry point to graphql. """GraphQL Queries""" @strawberry.field() - def enum(self, info: Info) -> GraphQLEnum: # type: ignore + def enum(self, info: Info[GraphQLContext, 'Query']) -> GraphQLEnum: # type: ignore return GraphQLEnum() @strawberry.field() async def cohort_templates( self, - info: Info, + info: Info[GraphQLContext, 'Query'], id: GraphQLFilter[str] | None = None, project: GraphQLFilter[str] | None = None, ) -> list[GraphQLCohortTemplate]: connection = info.context['connection'] cohort_layer = CohortLayer(connection) - ptable = ProjectPermissionsTable(connection) project_name_map: dict[str, int] = {} project_filter = None if project: project_names = project.all_values() - projects = await ptable.get_and_check_access_to_projects_for_names( - user=connection.author, project_names=project_names, readonly=True + projects = connection.get_and_check_access_to_projects_for_names( + project_names=project_names, allowed_roles=ReadAccessRoles ) project_name_map = {p.name: p.id for p in projects} project_filter = project.to_internal_filter_mapped( @@ -996,10 +1010,9 @@ async def cohort_templates( external_templates = [] for template in cohort_templates: - template_projects = await ptable.get_and_check_access_to_projects_for_ids( - user=connection.author, + template_projects = connection.get_and_check_access_to_projects_for_ids( project_ids=template.criteria.projects or [], - readonly=True, + allowed_roles=ReadAccessRoles, ) template_project_names = [p.name for p in template_projects if p.name] external_templates.append( @@ -1011,7 +1024,7 @@ async def cohort_templates( @strawberry.field() async def cohorts( self, - info: Info, + info: Info[GraphQLContext, 'Query'], id: GraphQLFilter[str] | None = None, project: GraphQLFilter[str] | None = None, name: GraphQLFilter[str] | None = None, @@ -1021,13 +1034,12 @@ async def cohorts( connection = info.context['connection'] cohort_layer = CohortLayer(connection) - ptable = ProjectPermissionsTable(connection) project_name_map: dict[str, int] = {} project_filter = None if project: project_names = project.all_values() - projects = await ptable.get_and_check_access_to_projects_for_names( - user=connection.author, project_names=project_names, readonly=True + projects = connection.get_and_check_access_to_projects_for_names( + project_names=project_names, allowed_roles=ReadAccessRoles ) project_name_map = {p.name: p.id for p in projects} project_filter = 
project.to_internal_filter_mapped( @@ -1052,18 +1064,19 @@ async def cohorts( return [GraphQLCohort.from_internal(cohort) for cohort in cohorts] @strawberry.field() - async def project(self, info: Info, name: str) -> GraphQLProject: + async def project( + self, info: Info[GraphQLContext, 'Query'], name: str + ) -> GraphQLProject: connection = info.context['connection'] - ptable = ProjectPermissionsTable(connection) - project = await ptable.get_and_check_access_to_project_for_name( - user=connection.author, project_name=name, allowed_roles=ReadAccessRoles + projects = connection.get_and_check_access_to_projects_for_names( + project_names=[name], allowed_roles=ReadAccessRoles ) - return GraphQLProject.from_internal(project) + return GraphQLProject.from_internal(next(p for p in projects)) @strawberry.field async def sample( self, - info: Info, + info: Info[GraphQLContext, 'Query'], id: GraphQLFilter[str] | None = None, project: GraphQLFilter[str] | None = None, type: GraphQLFilter[str] | None = None, @@ -1073,7 +1086,6 @@ async def sample( active: GraphQLFilter[bool] | None = None, ) -> list[GraphQLSample]: connection = info.context['connection'] - ptable = ProjectPermissionsTable(connection) slayer = SampleLayer(connection) if not id and not project: @@ -1085,8 +1097,7 @@ async def sample( project_name_map: dict[str, int] = {} if project: project_names = project.all_values() - projects = await ptable.get_and_check_access_to_projects_for_names( - user=connection.author, + projects = connection.get_and_check_access_to_projects_for_names( project_names=project_names, allowed_roles=ReadAccessRoles, ) @@ -1115,7 +1126,7 @@ async def sample( @strawberry.field async def sequencing_groups( self, - info: Info, + info: Info[GraphQLContext, 'Query'], id: GraphQLFilter[str] | None = None, project: GraphQLFilter[str] | None = None, sample_id: GraphQLFilter[str] | None = None, @@ -1130,7 +1141,6 @@ async def sequencing_groups( ) -> list[GraphQLSequencingGroup]: connection = info.context['connection'] sglayer = SequencingGroupLayer(connection) - ptable = ProjectPermissionsTable(connection) if not (project or sample_id or id): raise ValueError('Must filter by project, sample or id') @@ -1139,8 +1149,7 @@ async def sequencing_groups( if project: project_names = project.all_values() - projects = await ptable.get_and_check_access_to_projects_for_names( - user=connection.author, + projects = connection.get_and_check_access_to_projects_for_names( project_names=project_names, allowed_roles=ReadAccessRoles, ) @@ -1178,36 +1187,39 @@ async def sequencing_groups( return [GraphQLSequencingGroup.from_internal(sg) for sg in sgs] @strawberry.field - async def assay(self, info: Info, id: int) -> GraphQLAssay: + async def assay(self, info: Info[GraphQLContext, 'Query'], id: int) -> GraphQLAssay: connection = info.context['connection'] slayer = AssayLayer(connection) assay = await slayer.get_assay_by_id(id) return GraphQLAssay.from_internal(assay) @strawberry.field - async def participant(self, info: Info, id: int) -> GraphQLParticipant: - loader = info.context[LoaderKeys.PARTICIPANTS_FOR_IDS] + async def participant( + self, info: Info[GraphQLContext, 'Query'], id: int + ) -> GraphQLParticipant: + loader = info.context['loaders'][LoaderKeys.PARTICIPANTS_FOR_IDS] return GraphQLParticipant.from_internal(await loader.load(id)) @strawberry.field() - async def family(self, info: Info, family_id: int) -> GraphQLFamily: + async def family( + self, info: Info[GraphQLContext, 'Query'], family_id: int + ) -> GraphQLFamily: 
connection = info.context['connection'] family = await FamilyLayer(connection).get_family_by_internal_id(family_id) return GraphQLFamily.from_internal(family) @strawberry.field - async def my_projects(self, info: Info) -> list[GraphQLProject]: + async def my_projects( + self, info: Info[GraphQLContext, 'Query'] + ) -> list[GraphQLProject]: connection = info.context['connection'] - ptable = ProjectPermissionsTable(connection) - projects = await ptable.get_projects_accessible_by_user( - connection.author, allowed_roles=ReadAccessRoles - ) + projects = connection.all_projects() return [GraphQLProject.from_internal(p) for p in projects] @strawberry.field async def analysis_runner( self, - info: Info, + info: Info[GraphQLContext, 'Query'], ar_guid: str, ) -> GraphQLAnalysisRunner: if not ar_guid: From e30590f1bc6755e3d17af9496da5e0d8432c829b Mon Sep 17 00:00:00 2001 From: Dan Coates Date: Fri, 14 Jun 2024 16:53:43 +1000 Subject: [PATCH 10/29] Update routes, layers and table files to work with new permissions --- api/routes/analysis.py | 80 ++++++++--------- api/routes/analysis_runner.py | 21 ++--- api/routes/assay.py | 13 ++- api/routes/cohort.py | 47 +++++----- api/routes/family.py | 18 ++-- api/routes/imports.py | 5 +- api/routes/participant.py | 31 +++---- api/routes/project.py | 85 ++++++++++-------- api/routes/sample.py | 41 ++++----- api/routes/sequencing_groups.py | 8 +- api/routes/web.py | 19 ++-- db/python/layers/analysis.py | 60 ++++++------- db/python/layers/analysis_runner.py | 12 +-- db/python/layers/assay.py | 78 ++++++++--------- db/python/layers/audit_log.py | 15 ++-- db/python/layers/base.py | 3 +- db/python/layers/cohort.py | 28 +++--- db/python/layers/family.py | 84 ++++++++---------- db/python/layers/participant.py | 92 +++++++++---------- db/python/layers/sample.py | 126 +++++++++++---------------- db/python/layers/search.py | 14 ++- db/python/layers/seqr.py | 54 ++++++++---- db/python/layers/sequencing_group.py | 44 ++++------ db/python/layers/web.py | 32 ++++--- db/python/tables/analysis.py | 16 ++-- db/python/tables/analysis_runner.py | 2 +- db/python/tables/assay.py | 10 +-- db/python/tables/base.py | 1 + db/python/tables/cohort.py | 2 +- db/python/tables/family.py | 6 +- db/python/tables/participant.py | 6 +- db/python/tables/sample.py | 9 +- db/python/tables/sequencing_group.py | 4 +- 33 files changed, 491 insertions(+), 575 deletions(-) diff --git a/api/routes/analysis.py b/api/routes/analysis.py index 75e3af4a3..2ce161d1f 100644 --- a/api/routes/analysis.py +++ b/api/routes/analysis.py @@ -11,22 +11,16 @@ from api.utils.dates import parse_date_only_string from api.utils.db import ( Connection, - get_project_read_connection, - get_project_write_connection, + get_project_db_connection, get_projectless_db_connection, ) from api.utils.export import ExportType from db.python.layers.analysis import AnalysisLayer from db.python.tables.analysis import AnalysisFilter -from db.python.tables.project import ProjectPermissionsTable from db.python.utils import GenericFilter from models.enums import AnalysisStatus -from models.models.analysis import ( - Analysis, - AnalysisInternal, - ProportionalDateTemporalMethod, -) -from models.models.group import ReadAccessRoles +from models.models.analysis import Analysis, ProportionalDateTemporalMethod +from models.models.project import FullWriteAccessRoles, ReadAccessRoles from models.utils.sequencing_group_id_format import ( sequencing_group_id_format, sequencing_group_id_format_list, @@ -96,7 +90,8 @@ def to_filter(self, project_id_map: 
dict[str, int]) -> AnalysisFilter: @router.put('/{project}/', operation_id='createAnalysis', response_model=int) async def create_analysis( - analysis: Analysis, connection: Connection = get_project_write_connection + analysis: Analysis, + connection: Connection = get_project_db_connection(FullWriteAccessRoles), ) -> int: """Create a new analysis""" @@ -136,14 +131,14 @@ async def update_analysis( ) async def get_all_sample_ids_without_analysis_type( analysis_type: str, - connection: Connection = get_project_read_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), ): """get_all_sample_ids_without_analysis_type""" atable = AnalysisLayer(connection) - assert connection.project + assert connection.project_id sequencing_group_ids = ( await atable.get_all_sequencing_group_ids_without_analysis_type( - connection.project, analysis_type + connection.project_id, analysis_type ) ) return { @@ -156,12 +151,12 @@ async def get_all_sample_ids_without_analysis_type( operation_id='getIncompleteAnalyses', ) async def get_incomplete_analyses( - connection: Connection = get_project_read_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), ): """Get analyses with status queued or in-progress""" atable = AnalysisLayer(connection) - assert connection.project - results = await atable.get_incomplete_analyses(project=connection.project) + assert connection.project_id + results = await atable.get_incomplete_analyses(project=connection.project_id) return [r.to_external() for r in results] @@ -171,13 +166,13 @@ async def get_incomplete_analyses( ) async def get_latest_complete_analysis_for_type( analysis_type: str, - connection: Connection = get_project_read_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), ): """Get (SINGLE) latest complete analysis for some analysis type""" alayer = AnalysisLayer(connection) - assert connection.project + assert connection.project_id analysis = await alayer.get_latest_complete_analysis_for_type( - project=connection.project, analysis_type=analysis_type + project=connection.project_id, analysis_type=analysis_type ) return analysis.to_external() @@ -189,16 +184,16 @@ async def get_latest_complete_analysis_for_type( async def get_latest_complete_analysis_for_type_post( analysis_type: str, meta: dict[str, Any] = Body(..., embed=True), # type: ignore[assignment] - connection: Connection = get_project_read_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), ): """ Get SINGLE latest complete analysis for some analysis type (you can specify meta attributes in this route) """ alayer = AnalysisLayer(connection) - assert connection.project + assert connection.project_id analysis = await alayer.get_latest_complete_analysis_for_type( - project=connection.project, + project=connection.project_id, analysis_type=analysis_type, meta=meta, ) @@ -229,11 +224,8 @@ async def query_analyses( if not query.projects: raise ValueError('Must specify "projects"') - pt = ProjectPermissionsTable(connection) - projects = await pt.get_and_check_access_to_projects_for_names( - user=connection.author, - project_names=query.projects, - allowed_roles=ReadAccessRoles, + projects = connection.get_and_check_access_to_projects_for_names( + query.projects, ReadAccessRoles ) project_name_map = {p.name: p.id for p in projects} atable = AnalysisLayer(connection) @@ -243,27 +235,24 @@ async def query_analyses( @router.get('/analysis-runner', operation_id='getAnalysisRunnerLog') async def 
get_analysis_runner_log( - project_names: list[str] = Query(None), # type: ignore - # author: str = None, # not implemented yet, uncomment when we do - output_dir: str = None, - ar_guid: str = None, + project_names: list[str], + output_dir: str, + ar_guid: str | None = None, connection: Connection = get_projectless_db_connection, -) -> list[AnalysisInternal]: +) -> list[Analysis]: """ Get log for the analysis-runner, useful for checking this history of analysis """ atable = AnalysisLayer(connection) project_ids = None - if project_names: - pt = ProjectPermissionsTable(connection) - projects = await pt.get_and_check_access_to_projects_for_names( - connection.author, project_names, allowed_roles=ReadAccessRoles - ) - project_ids = [p.id for p in projects] + + projects = connection.get_and_check_access_to_projects_for_names( + project_names, allowed_roles=ReadAccessRoles + ) + project_ids = [p.id for p in projects] results = await atable.get_analysis_runner_log( project_ids=project_ids, - # author=author, output_dir=output_dir, ar_guid=ar_guid, ) @@ -278,7 +267,7 @@ async def get_analysis_runner_log( async def get_sample_reads_map( export_type: ExportType = ExportType.JSON, sequencing_types: list[str] = Query(None), # type: ignore - connection: Connection = get_project_read_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), ): """ Get map of ExternalSampleId pathToCram InternalSeqGroupID for seqr @@ -293,9 +282,9 @@ async def get_sample_reads_map( """ at = AnalysisLayer(connection) - assert connection.project + assert connection.project_id objs = await at.get_sample_cram_path_map_for_seqr( - project=connection.project, sequencing_types=sequencing_types + project=connection.project_id, sequencing_types=sequencing_types ) for r in objs: @@ -311,7 +300,7 @@ async def get_sample_reads_map( writer = csv.writer(output, delimiter=export_type.get_delimiter()) writer.writerows(rows) - basefn = f'{connection.project}-seqr-igv-paths-{date.today().isoformat()}' + basefn = f'{connection.project_id}-seqr-igv-paths-{date.today().isoformat()}' return StreamingResponse( iter([output.getvalue()]), @@ -345,9 +334,8 @@ async def get_proportionate_map( } } """ - pt = ProjectPermissionsTable(connection) - project_list = await pt.get_and_check_access_to_projects_for_names( - connection.author, projects, allowed_roles=ReadAccessRoles + project_list = connection.get_and_check_access_to_projects_for_names( + projects, allowed_roles=ReadAccessRoles ) project_ids = [p.id for p in project_list] diff --git a/api/routes/analysis_runner.py b/api/routes/analysis_runner.py index 893fba262..0d1d235f9 100644 --- a/api/routes/analysis_runner.py +++ b/api/routes/analysis_runner.py @@ -2,15 +2,12 @@ from fastapi import APIRouter -from api.utils.db import ( - Connection, - get_project_read_connection, - get_project_write_connection, -) +from api.utils.db import Connection, get_project_db_connection from db.python.layers.analysis_runner import AnalysisRunnerLayer from db.python.tables.analysis_runner import AnalysisRunnerFilter from db.python.utils import GenericFilter from models.models.analysis_runner import AnalysisRunner, AnalysisRunnerInternal +from models.models.project import FullWriteAccessRoles, ReadAccessRoles router = APIRouter(prefix='/analysis-runner', tags=['analysis-runner']) @@ -32,13 +29,13 @@ async def create_analysis_runner_log( # pylint: disable=too-many-arguments output_path: str, hail_version: str | None = None, cwd: str | None = None, - connection: Connection = 
get_project_write_connection, + connection: Connection = get_project_db_connection(FullWriteAccessRoles), ) -> str: """Create a new analysis runner log""" alayer = AnalysisRunnerLayer(connection) - if not connection.project: + if not connection.project_id: raise ValueError('Project not set') analysis_id = await alayer.insert_analysis_runner_entry( @@ -58,7 +55,7 @@ async def create_analysis_runner_log( # pylint: disable=too-many-arguments batch_url=batch_url, submitting_user=submitting_user, meta=meta, - project=connection.project, + project=connection.project_id, audit_log_id=None, output_path=output_path, ) @@ -75,13 +72,13 @@ async def get_analysis_runner_logs( repository: str | None = None, access_level: str | None = None, environment: str | None = None, - connection: Connection = get_project_read_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), ) -> list[AnalysisRunner]: """Get analysis runner logs""" atable = AnalysisRunnerLayer(connection) - if not connection.project: + if not connection.project_id: raise ValueError('Project not set') filter_ = AnalysisRunnerFilter( @@ -90,9 +87,9 @@ async def get_analysis_runner_logs( repository=GenericFilter(eq=repository), access_level=GenericFilter(eq=access_level), environment=GenericFilter(eq=environment), - project=GenericFilter(eq=connection.project), + project=GenericFilter(eq=connection.project_id), ) logs = await atable.query(filter_) - return [log.to_external({connection.project: project}) for log in logs] + return [log.to_external({connection.project_id: project}) for log in logs] diff --git a/api/routes/assay.py b/api/routes/assay.py index dcf584f54..e543f2ae8 100644 --- a/api/routes/assay.py +++ b/api/routes/assay.py @@ -4,16 +4,15 @@ from api.utils.db import ( Connection, - get_project_read_connection, + get_project_db_connection, get_projectless_db_connection, ) from db.python.layers.assay import AssayLayer from db.python.tables.assay import AssayFilter -from db.python.tables.project import ProjectPermissionsTable from db.python.utils import GenericFilter from models.base import SMBase from models.models.assay import AssayUpsert -from models.models.group import ReadAccessRoles +from models.models.project import ReadAccessRoles from models.utils.sample_id_format import sample_id_transform_to_raw_list router = APIRouter(prefix='/assay', tags=['assay']) @@ -57,7 +56,8 @@ async def get_assay_by_id( '/{project}/external_id/{external_id}/details', operation_id='getAssayByExternalId' ) async def get_assay_by_external_id( - external_id: str, connection=get_project_read_connection + external_id: str, + connection: Connection = get_project_db_connection(ReadAccessRoles), ): """Get an assay by ONE of its external identifiers""" assay_layer = AssayLayer(connection) @@ -84,12 +84,11 @@ async def get_assays_by_criteria( ): """Get assays by criteria""" assay_layer = AssayLayer(connection) - pt = ProjectPermissionsTable(connection) pids: list[int] | None = None if criteria.projects: - project_list = await pt.get_and_check_access_to_projects_for_names( - connection.author, criteria.projects, allowed_roles=ReadAccessRoles + project_list = connection.get_and_check_access_to_projects_for_names( + criteria.projects, allowed_roles=ReadAccessRoles ) pids = [p.id for p in project_list] diff --git a/api/routes/cohort.py b/api/routes/cohort.py index 0ea555a2f..4a5568926 100644 --- a/api/routes/cohort.py +++ b/api/routes/cohort.py @@ -1,10 +1,9 @@ from fastapi import APIRouter -from api.utils.db import Connection, 
HACK_get_project_contributor_connection +from api.utils.db import Connection, get_project_db_connection from db.python.layers.cohort import CohortLayer -from db.python.tables.project import ProjectPermissionsTable from models.models.cohort import CohortBody, CohortCriteria, CohortTemplate, NewCohort -from models.models.project import ProjectId +from models.models.project import ProjectId, ProjectMemberRole, ReadAccessRoles from models.utils.cohort_template_id_format import ( cohort_template_id_format, cohort_template_id_transform_to_raw, @@ -17,7 +16,9 @@ async def create_cohort_from_criteria( cohort_spec: CohortBody, cohort_criteria: CohortCriteria | None = None, - connection: Connection = HACK_get_project_contributor_connection, + connection: Connection = get_project_db_connection( + {ProjectMemberRole.writer, ProjectMemberRole.contributor} + ), dry_run: bool = False, ) -> NewCohort: """ @@ -35,24 +36,21 @@ async def create_cohort_from_criteria( internal_project_ids: list[ProjectId] = [] - if cohort_criteria: - if cohort_criteria.projects: - pt = ProjectPermissionsTable(connection) - projects = await pt.get_and_check_access_to_projects_for_names( - user=connection.author, - project_names=cohort_criteria.projects, - readonly=True, - ) - if projects: - internal_project_ids = [p.id for p in projects if p.id] + if cohort_criteria and cohort_criteria.projects: + projects = connection.get_and_check_access_to_projects_for_names( + project_names=cohort_criteria.projects, allowed_roles=ReadAccessRoles + ) + + internal_project_ids = [p.id for p in projects if p.id] template_id_raw = ( cohort_template_id_transform_to_raw(cohort_spec.template_id) if cohort_spec.template_id else None ) + assert connection.project_id cohort_output = await cohort_layer.create_cohort_from_criteria( - project_to_write=connection.project, + project_to_write=connection.project_id, description=cohort_spec.description, cohort_name=cohort_spec.name, dry_run=dry_run, @@ -70,7 +68,9 @@ async def create_cohort_from_criteria( @router.post('/{project}/cohort_template', operation_id='createCohortTemplate') async def create_cohort_template( template: CohortTemplate, - connection: Connection = HACK_get_project_contributor_connection, + connection: Connection = get_project_db_connection( + {ProjectMemberRole.writer, ProjectMemberRole.contributor} + ), ) -> str: """ Create a cohort template with the given name and sample/sequencing group IDs. 
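For illustration only (not part of the patch): a minimal sketch of the route pattern this commit converges on, where project routes obtain their connection via get_project_db_connection with an explicit set of allowed roles instead of the old read/write-specific dependencies. The endpoint below is invented for the example; only get_project_db_connection, ReadAccessRoles, Connection and connection.project_id come from the patch.

    # Hypothetical read-only route using the new dependency (sketch, not patch content).
    @router.get('/{project}/example', operation_id='exampleReadOnlyRoute')
    async def example_read_only_route(
        connection: Connection = get_project_db_connection(ReadAccessRoles),
    ):
        """Any role in ReadAccessRoles grants access; the {project} path parameter
        is resolved onto connection.project_id, as in the surrounding routes."""
        assert connection.project_id
        return {'project_id': connection.project_id}
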
@@ -83,20 +83,19 @@ async def create_cohort_template( criteria_project_ids: list[ProjectId] = [] if template.criteria.projects: - pt = ProjectPermissionsTable(connection) - projects_for_criteria = await pt.get_and_check_access_to_projects_for_names( - user=connection.author, + projects_for_criteria = connection.get_and_check_access_to_projects_for_names( project_names=template.criteria.projects, - readonly=False, + allowed_roles=ReadAccessRoles, ) - if projects_for_criteria: - criteria_project_ids = [p.id for p in projects_for_criteria if p.id] + criteria_project_ids = [p.id for p in projects_for_criteria if p.id] + assert connection.project_id cohort_raw_id = await cohort_layer.create_cohort_template( cohort_template=template.to_internal( - criteria_projects=criteria_project_ids, template_project=connection.project + criteria_projects=criteria_project_ids, + template_project=connection.project_id, ), - project=connection.project, + project=connection.project_id, ) return cohort_template_id_format(cohort_raw_id) diff --git a/api/routes/family.py b/api/routes/family.py index aa7827090..c1bccb714 100644 --- a/api/routes/family.py +++ b/api/routes/family.py @@ -11,8 +11,7 @@ from api.utils.db import ( Connection, - get_project_read_connection, - get_project_write_connection, + get_project_db_connection, get_projectless_db_connection, ) from api.utils.export import ExportType @@ -21,6 +20,7 @@ from db.python.tables.family import FamilyFilter from db.python.utils import GenericFilter from models.models.family import Family +from models.models.project import FullWriteAccessRoles, ReadAccessRoles from models.utils.sample_id_format import sample_id_transform_to_raw_list router = APIRouter(prefix='/family', tags=['family']) @@ -41,7 +41,7 @@ async def import_pedigree( has_header: bool = False, create_missing_participants: bool = False, perform_sex_check: bool = True, - connection: Connection = get_project_write_connection, + connection: Connection = get_project_db_connection(FullWriteAccessRoles), ): """Import a pedigree""" delimiter = guess_delimiter_by_upload_file_obj(file) @@ -75,7 +75,7 @@ async def get_pedigree( replace_with_family_external_ids: bool = True, include_header: bool = True, empty_participant_value: Optional[str] = None, - connection: Connection = get_project_read_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), include_participants_not_in_families: bool = False, ): """ @@ -87,9 +87,9 @@ async def get_pedigree( """ family_layer = FamilyLayer(connection) - assert connection.project + assert connection.project_id pedigree_dicts = await family_layer.get_pedigree( - project=connection.project, + project=connection.project_id, family_ids=internal_family_ids, replace_with_participant_external_ids=replace_with_participant_external_ids, replace_with_family_external_ids=replace_with_family_external_ids, @@ -120,7 +120,7 @@ async def get_pedigree( ] writer.writerows(pedigree_rows) - basefn = f'{connection.project}-{date.today().isoformat()}' + basefn = f'{connection.project_id}-{date.today().isoformat()}' if internal_family_ids: basefn += '-'.join(str(fm) for fm in internal_family_ids) @@ -142,7 +142,7 @@ async def get_pedigree( async def get_families( participant_ids: Optional[List[int]] = Query(None), sample_ids: Optional[List[str]] = Query(None), - connection: Connection = get_project_read_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), ) -> List[Family]: """Get families for some project""" family_layer = 
FamilyLayer(connection) @@ -182,7 +182,7 @@ async def import_families( file: UploadFile = File(...), has_header: bool = True, delimiter: str | None = None, - connection: Connection = get_project_write_connection, + connection: Connection = get_project_db_connection(FullWriteAccessRoles), ): """Import a family csv""" delimiter = guess_delimiter_by_upload_file_obj(file, default_delimiter=delimiter) diff --git a/api/routes/imports.py b/api/routes/imports.py index 7f8f84bf2..66dc0531e 100644 --- a/api/routes/imports.py +++ b/api/routes/imports.py @@ -4,12 +4,13 @@ from fastapi import APIRouter, File, UploadFile -from api.utils.db import Connection, get_project_write_connection +from api.utils.db import Connection, get_project_db_connection from api.utils.extensions import guess_delimiter_by_upload_file_obj from db.python.layers.participant import ( ExtraParticipantImporterHandler, ParticipantLayer, ) +from models.models.project import FullWriteAccessRoles router = APIRouter(prefix='/import', tags=['import']) @@ -23,7 +24,7 @@ async def import_individual_metadata_manifest( file: UploadFile = File(...), delimiter: Optional[str] = None, extra_participants_method: ExtraParticipantImporterHandler = ExtraParticipantImporterHandler.FAIL, - connection: Connection = get_project_write_connection, + connection: Connection = get_project_db_connection(FullWriteAccessRoles), ): """ Import individual metadata manifest diff --git a/api/routes/participant.py b/api/routes/participant.py index 4a9409f34..f6a9ddc22 100644 --- a/api/routes/participant.py +++ b/api/routes/participant.py @@ -8,14 +8,14 @@ from api.utils.db import ( Connection, - get_project_read_connection, - get_project_write_connection, + get_project_db_connection, get_projectless_db_connection, ) from api.utils.export import ExportType from db.python.layers.participant import ParticipantLayer from models.base import SMBase from models.models.participant import ParticipantUpsert +from models.models.project import FullWriteAccessRoles, ReadAccessRoles from models.models.sequencing_group import sequencing_group_id_format router = APIRouter(prefix='/participant', tags=['participant']) @@ -25,7 +25,7 @@ '/{project}/fill-in-missing-participants', operation_id='fillInMissingParticipants' ) async def fill_in_missing_participants( - connection: Connection = get_project_write_connection, + connection: Connection = get_project_db_connection(FullWriteAccessRoles), ): """ Create a corresponding participant (if required) @@ -47,13 +47,13 @@ async def get_individual_metadata_template_for_seqr( external_participant_ids: list[str] | None = Query(default=None), # type: ignore[assignment] # pylint: disable=invalid-name replace_with_participant_external_ids: bool = True, - connection: Connection = get_project_read_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), ): """Get individual metadata template for SEQR as a CSV""" participant_layer = ParticipantLayer(connection) - assert connection.project + assert connection.project_id resp = await participant_layer.get_seqr_individual_template( - project=connection.project, + project=connection.project_id, external_participant_ids=external_participant_ids, replace_with_participant_external_ids=replace_with_participant_external_ids, ) @@ -90,14 +90,14 @@ async def get_individual_metadata_template_for_seqr( async def get_id_map_by_external_ids( external_participant_ids: list[str], allow_missing: bool = False, - connection: Connection = get_project_read_connection, + connection: Connection = 
get_project_db_connection(ReadAccessRoles), ): """Get ID map of participants, by external_id""" player = ParticipantLayer(connection) return await player.get_id_map_by_external_ids( external_participant_ids, allow_missing=allow_missing, - project=connection.project, + project=connection.project_id, ) @@ -121,7 +121,7 @@ async def get_external_participant_id_to_sequencing_group_id( sequencing_type: str = None, export_type: ExportType = ExportType.JSON, flip_columns: bool = False, - connection: Connection = get_project_read_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), ): """ Get csv / tsv export of external_participant_id to sequencing_group_id @@ -136,10 +136,10 @@ async def get_external_participant_id_to_sequencing_group_id( :param flip_columns: Set to True when exporting for seqr """ player = ParticipantLayer(connection) - # this wants project ID (connection.project) - assert connection.project + # this wants project ID (connection.project_id) + assert connection.project_id m = await player.get_external_participant_id_to_internal_sequencing_group_id_map( - project=connection.project, sequencing_type=sequencing_type + project=connection.project_id, sequencing_type=sequencing_type ) rows = [[pid, sequencing_group_id_format(sgid)] for pid, sgid in m] @@ -189,7 +189,7 @@ async def update_participant( ) async def upsert_participants( participants: list[ParticipantUpsert], - connection: Connection = get_project_write_connection, + connection: Connection = get_project_db_connection(FullWriteAccessRoles), ): """ Upserts a list of participants with samples and sequences @@ -214,12 +214,13 @@ class QueryParticipantCriteria(SMBase): ) async def get_participants( criteria: QueryParticipantCriteria, - connection: Connection = get_project_read_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), ): """Get participants, default ALL participants in project""" player = ParticipantLayer(connection) + assert connection.project_id participants = await player.get_participants( - project=connection.project, + project=connection.project_id, external_participant_ids=( criteria.external_participant_ids if criteria else None ), diff --git a/api/routes/project.py b/api/routes/project.py index e375066a3..799525fd6 100644 --- a/api/routes/project.py +++ b/api/routes/project.py @@ -1,30 +1,26 @@ -from typing import List - from fastapi import APIRouter -from api.utils.db import Connection, get_projectless_db_connection -from db.python.tables.project import ProjectPermissionsTable -from models.models.group import FullWriteAccessRoles, GroupProjectRole, ReadAccessRoles -from models.models.project import Project +from api.utils.db import ( + Connection, + get_project_db_connection, + get_projectless_db_connection, +) +from db.python.tables.project import ProjectMemberWithRole, ProjectPermissionsTable +from models.models.project import FullWriteAccessRoles, Project, ProjectMemberRole router = APIRouter(prefix='/project', tags=['project']) -@router.get('/all', operation_id='getAllProjects', response_model=List[Project]) -async def get_all_projects(connection=get_projectless_db_connection): +@router.get('/all', operation_id='getAllProjects', response_model=list[Project]) +async def get_all_projects(connection: Connection = get_projectless_db_connection): """Get list of projects""" - ptable = ProjectPermissionsTable(connection) - return await ptable.get_all_projects(author=connection.author) + return connection.all_projects() -@router.get('/', 
operation_id='getMyProjects', response_model=List[str]) -async def get_my_projects(connection=get_projectless_db_connection): +@router.get('/', operation_id='getMyProjects', response_model=list[str]) +async def get_my_projects(connection: Connection = get_projectless_db_connection): """Get projects I have access to""" - ptable = ProjectPermissionsTable(connection) - projects = await ptable.get_projects_accessible_by_user( - user=connection.author, allowed_roles=ReadAccessRoles - ) - return [p.name for p in projects] + return [p.name for p in connection.all_projects()] @router.put('/', operation_id='createProject') @@ -38,6 +34,10 @@ async def create_project( Create a new project """ ptable = ProjectPermissionsTable(connection) + + # Creating a project requires the project creator permission + await ptable.check_project_creator_permissions(author=connection.author) + pid = await ptable.create_project( project_name=name, dataset_name=dataset, @@ -57,28 +57,32 @@ async def create_project( @router.get('/seqr/all', operation_id='getSeqrProjects') async def get_seqr_projects(connection: Connection = get_projectless_db_connection): """Get SM projects that should sync to seqr""" - ptable = ProjectPermissionsTable(connection) - return await ptable.get_seqr_projects() + projects = connection.all_projects() + return [p for p in projects if p.meta and p.meta.get('is_seqr')] @router.post('/{project}/update', operation_id='updateProject') async def update_project( - project: str, project_update_model: dict, - connection: Connection = get_projectless_db_connection, + connection: Connection = get_project_db_connection(FullWriteAccessRoles), ): """Update a project by project name""" ptable = ProjectPermissionsTable(connection) + # Updating a project additionally requires the project creator permission + await ptable.check_project_creator_permissions(author=connection.author) + project = connection.project + assert project return await ptable.update_project( - project_name=project, update=project_update_model, author=connection.author + project_name=project.name, update=project_update_model, author=connection.author ) @router.delete('/{project}', operation_id='deleteProjectData') async def delete_project_data( - project: str, delete_project: bool = False, - connection: Connection = get_projectless_db_connection, + connection: Connection = get_project_db_connection( + {ProjectMemberRole.writer, ProjectMemberRole.data_manager} + ), ): """ Delete all data in a project by project name. 
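
The project routes above also show the split between the two connection dependencies: routes not scoped to a single project keep get_projectless_db_connection and read from connection.all_projects(), while project-scoped routes pass their allowed roles to get_project_db_connection. A minimal sketch contrasting the two, assuming the same helpers shown in this patch; the prefix, handler names and return types are illustrative only.

from fastapi import APIRouter

from api.utils.db import (
    Connection,
    get_project_db_connection,
    get_projectless_db_connection,
)
from models.models.project import FullWriteAccessRoles

router = APIRouter(prefix='/example', tags=['example'])


@router.get('/', operation_id='listExampleProjects')
async def list_example_projects(
    connection: Connection = get_projectless_db_connection,
) -> list[str]:
    # No project in the path: the connection already carries every project
    # the requesting user is a member of.
    return [p.name for p in connection.all_projects()]


@router.post('/{project}/example', operation_id='writeExample')
async def write_example(
    connection: Connection = get_project_db_connection(FullWriteAccessRoles),
) -> int:
    # The dependency has already verified a write-capable role on `{project}`,
    # so the handler only needs the resolved project id.
    assert connection.project_id
    return connection.project_id
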
@@ -86,11 +90,22 @@ async def delete_project_data( Requires READ access + project-creator permissions """ ptable = ProjectPermissionsTable(connection) - p_obj = await ptable.get_and_check_access_to_project_for_name( - user=connection.author, project_name=project, allowed_roles=FullWriteAccessRoles + + assert connection.project + + # Allow data manager role to delete test projects + data_manager_deleting_test = ( + connection.project.is_test + and connection.project.roles & {ProjectMemberRole.data_manager} ) + if not data_manager_deleting_test: + # Otherwise, deleting a project additionally requires the project creator permission + await ptable.check_project_creator_permissions(author=connection.author) + success = await ptable.delete_project_data( - project_id=p_obj.id, delete_project=delete_project, author=connection.author + project_id=connection.project.id, + delete_project=delete_project, + author=connection.author, ) return {'success': success} @@ -98,21 +113,17 @@ async def delete_project_data( @router.patch('/{project}/members', operation_id='updateProjectMembers') async def update_project_members( - project: str, - members: list[str], - # @TODO change this to accept a role - role: str, - connection: Connection = get_projectless_db_connection, + members: list[ProjectMemberWithRole], + connection: Connection = get_project_db_connection(FullWriteAccessRoles), ): """ Update project members for specific read / write group. Not that this is protected by access to a specific access group """ ptable = ProjectPermissionsTable(connection) - await ptable.set_group_members( - group_name=ptable.get_project_group_name(project, role=GroupProjectRole(role)), - members=members, - author=connection.author, - ) + + await ptable.check_member_admin_permissions(author=connection.author) + assert connection.project + await ptable.set_project_members(project=connection.project, members=members) return {'success': True} diff --git a/api/routes/sample.py b/api/routes/sample.py index 8919c5d52..b265239f2 100644 --- a/api/routes/sample.py +++ b/api/routes/sample.py @@ -2,14 +2,12 @@ from api.utils.db import ( Connection, - get_project_read_connection, - get_project_write_connection, + get_project_db_connection, get_projectless_db_connection, ) from db.python.layers.sample import SampleLayer -from db.python.tables.project import ProjectPermissionsTable from models.base import SMBase -from models.models.group import ReadAccessRoles +from models.models.project import FullWriteAccessRoles, ReadAccessRoles from models.models.sample import SampleUpsert from models.utils.sample_id_format import ( # Sample, sample_id_format, @@ -24,7 +22,8 @@ @router.put('/{project}/', response_model=str, operation_id='createSample') async def create_sample( - sample: SampleUpsert, connection: Connection = get_project_write_connection + sample: SampleUpsert, + connection: Connection = get_project_db_connection(FullWriteAccessRoles), ) -> str: """Creates a new sample, and returns the internal sample ID""" st = SampleLayer(connection) @@ -38,7 +37,7 @@ async def create_sample( ) async def upsert_samples( samples: list[SampleUpsert], - connection: Connection = get_project_write_connection, + connection: Connection = get_project_db_connection(FullWriteAccessRoles), ): """ Upserts a list of samples with sequencing-groups, @@ -63,7 +62,7 @@ async def upsert_samples( async def get_sample_id_map_by_external( external_ids: list[str], allow_missing: bool = False, - connection: Connection = get_project_read_connection, + connection: Connection 
= get_project_db_connection(ReadAccessRoles), ): """Get map of sample IDs, { [externalId]: internal_sample_id }""" st = SampleLayer(connection) @@ -92,12 +91,14 @@ async def get_sample_id_map_by_internal( '/{project}/id-map/internal/all', operation_id='getAllSampleIdMapByInternal' ) async def get_all_sample_id_map_by_internal( - connection: Connection = get_project_read_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), ): """Get map of ALL sample IDs, { [internal_id]: external_sample_id }""" st = SampleLayer(connection) - assert connection.project - result = await st.get_all_sample_id_map_by_internal_ids(project=connection.project) + assert connection.project_id + result = await st.get_all_sample_id_map_by_internal_ids( + project=connection.project_id + ) return {sample_id_format(k): v for k, v in result.items()} @@ -108,12 +109,15 @@ async def get_all_sample_id_map_by_internal( operation_id='getSampleByExternalId', ) async def get_sample_by_external_id( - external_id: str, connection: Connection = get_project_read_connection + external_id: str, + connection: Connection = get_project_db_connection(ReadAccessRoles), ): """Get sample by external ID""" st = SampleLayer(connection) - assert connection.project - result = await st.get_single_by_external_id(external_id, project=connection.project) + assert connection.project_id + result = await st.get_single_by_external_id( + external_id, project=connection.project_id + ) return result.to_external() @@ -130,11 +134,6 @@ class GetSamplesCriteria(SMBase): @router.post('/', operation_id='getSamples') async def get_samples( criteria: GetSamplesCriteria, - meta: dict = None, - participant_ids: list[int] = None, - # project_ids is inaccurately named, it should be `project_names` - project_ids: list[str] = None, - active: bool = Body(default=True), connection: Connection = get_projectless_db_connection, ): """ @@ -142,11 +141,10 @@ async def get_samples( """ st = SampleLayer(connection) - pt = ProjectPermissionsTable(connection) pids: list[int] | None = None if criteria.project_ids: - projects = await pt.get_and_check_access_to_projects_for_names( - connection.author, criteria.project_ids, allowed_roles=ReadAccessRoles + projects = connection.get_and_check_access_to_projects_for_names( + criteria.project_ids, allowed_roles=ReadAccessRoles ) pids = [p.id for p in projects] @@ -162,7 +160,6 @@ async def get_samples( participant_ids=criteria.participant_ids, project_ids=pids, active=criteria.active, - check_project_ids=True, ) return [r.to_external() for r in result] diff --git a/api/routes/sequencing_groups.py b/api/routes/sequencing_groups.py index d9842c5b1..09c3c44a1 100644 --- a/api/routes/sequencing_groups.py +++ b/api/routes/sequencing_groups.py @@ -5,11 +5,11 @@ from api.utils.db import ( Connection, - get_project_read_connection, - get_project_write_connection, + get_project_db_connection, get_projectless_db_connection, ) from db.python.layers.sequencing_group import SequencingGroupLayer +from models.models.project import FullWriteAccessRoles, ReadAccessRoles from models.models.sequencing_group import SequencingGroupUpsertInternal from models.utils.sample_id_format import sample_id_format from models.utils.sequencing_group_id_format import ( # Sample, @@ -42,7 +42,7 @@ async def get_sequencing_group( @router.get('/project/{project}', operation_id='getAllSequencingGroupIdsBySampleByType') async def get_all_sequencing_group_ids_by_sample_by_type( - connection: Connection = get_project_read_connection, + connection: 
Connection = get_project_db_connection(ReadAccessRoles), ) -> dict[str, dict[str, list[str]]]: """Creates a new sample, and returns the internal sample ID""" st = SequencingGroupLayer(connection) @@ -60,7 +60,7 @@ async def get_all_sequencing_group_ids_by_sample_by_type( async def update_sequencing_group( sequencing_group_id: str, sequencing_group: SequencingGroupMetaUpdateModel, - connection: Connection = get_project_write_connection, + connection: Connection = get_project_db_connection(FullWriteAccessRoles), ) -> bool: """Update the meta fields of a sequencing group""" st = SequencingGroupLayer(connection) diff --git a/api/routes/web.py b/api/routes/web.py index e52a5620e..c4a58f778 100644 --- a/api/routes/web.py +++ b/api/routes/web.py @@ -10,16 +10,14 @@ from api.utils.db import ( Connection, - get_project_read_connection, - get_project_write_connection, + get_project_db_connection, get_projectless_db_connection, ) from db.python.layers.search import SearchLayer from db.python.layers.seqr import SeqrLayer from db.python.layers.web import SearchItem, WebLayer -from db.python.tables.project import ProjectPermissionsTable from models.enums.web import SeqrDatasetType -from models.models.group import ReadAccessRoles +from models.models.project import FullWriteAccessRoles, ReadAccessRoles from models.models.search import SearchResponse from models.models.web import PagingLinks, ProjectSummary @@ -43,7 +41,7 @@ async def get_project_summary( grid_filter: list[SearchItem], limit: int = 20, token: Optional[int] = 0, - connection: Connection = get_project_read_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), ) -> ProjectSummary: """Creates a new sample, and returns the internal sample ID""" st = WebLayer(connection) @@ -75,16 +73,15 @@ async def get_project_summary( @router.get( '/search', response_model=SearchResponseModel, operation_id='searchByKeyword' ) -async def search_by_keyword(keyword: str, connection=get_projectless_db_connection): +async def search_by_keyword( + keyword: str, connection: Connection = get_projectless_db_connection +): """ This searches the keyword, in families, participants + samples in the projects that you are a part of (automatically). 
""" # raise ValueError("Test") - pt = ProjectPermissionsTable(connection) - projects = await pt.get_projects_accessible_by_user( - connection.author, allowed_roles=ReadAccessRoles - ) + projects = connection.all_projects() pmap = {p.id: p for p in projects} responses = await SearchLayer(connection).search( keyword, project_ids=list(pmap.keys()) @@ -114,7 +111,7 @@ async def sync_seqr_project( sync_saved_variants: bool = True, sync_cram_map: bool = True, post_slack_notification: bool = True, - connection=get_project_write_connection, + connection: Connection = get_project_db_connection(FullWriteAccessRoles), ): """ Sync a metamist project with its seqr project (for a specific sequence type) diff --git a/db/python/layers/analysis.py b/db/python/layers/analysis.py index db469468b..308c74e1f 100644 --- a/db/python/layers/analysis.py +++ b/db/python/layers/analysis.py @@ -20,8 +20,7 @@ ProportionalDateTemporalMethod, SequencingGroupInternal, ) -from models.models.group import FullWriteAccessRoles, ReadAccessRoles -from models.models.project import ProjectId +from models.models.project import FullWriteAccessRoles, ProjectId, ReadAccessRoles from models.models.sequencing_group import SequencingGroupInternalId ES_ANALYSIS_OBJ_INTRO_DATE = datetime.date(2022, 6, 21) @@ -54,13 +53,13 @@ def __init__(self, connection: Connection): # GETS - async def get_analysis_by_id(self, analysis_id: int, check_project_id=True): + async def get_analysis_by_id(self, analysis_id: int): """Get analysis by ID""" project, analysis = await self.at.get_analysis_by_id(analysis_id) - if check_project_id: - await self.ptable.check_access_to_project_id( - self.author, project, allowed_roles=ReadAccessRoles - ) + + self.connection.check_access_to_projects_for_ids( + [project], allowed_roles=ReadAccessRoles + ) return analysis @@ -115,21 +114,17 @@ async def get_sample_cram_path_map_for_seqr( participant_ids=participant_ids, ) - async def query( - self, filter_: AnalysisFilter, check_project_ids: bool = True - ) -> list[AnalysisInternal]: + async def query(self, filter_: AnalysisFilter) -> list[AnalysisInternal]: """Query analyses""" analyses = await self.at.query(filter_) if not analyses: return [] - if check_project_ids and not filter_.project: - await self.ptable.check_access_to_project_ids( - self.author, - set(a.project for a in analyses if a.project is not None), - allowed_roles=ReadAccessRoles, - ) + self.connection.check_access_to_projects_for_ids( + set(a.project for a in analyses if a.project is not None), + allowed_roles=ReadAccessRoles, + ) return analyses @@ -160,8 +155,8 @@ async def get_cram_size_proportionate_map( if start_date < datetime.date(2020, 1, 1): raise ValueError(f'start_date ({start_date}) must be after 2020-01-01') - project_objs = await self.ptable.get_and_check_access_to_projects_for_ids( - project_ids=projects, user=self.author, allowed_roles=ReadAccessRoles + project_objs = self.connection.get_and_check_access_to_projects_for_ids( + project_ids=projects, allowed_roles=ReadAccessRoles ) project_name_map = {p.id: p.name for p in project_objs} @@ -558,14 +553,13 @@ async def create_analysis( ) async def add_sequencing_groups_to_analysis( - self, analysis_id: int, sequencing_group_ids: list[int], check_project_id=True + self, analysis_id: int, sequencing_group_ids: list[int] ): """Add samples to an analysis (through the linked table)""" - if check_project_id: - project_ids = await self.at.get_project_ids_for_analysis_ids([analysis_id]) - await self.ptable.check_access_to_project_ids( - self.author, 
project_ids, allowed_roles=FullWriteAccessRoles - ) + project_ids = await self.at.get_project_ids_for_analysis_ids([analysis_id]) + self.connection.check_access_to_projects_for_ids( + project_ids, allowed_roles=FullWriteAccessRoles + ) return await self.at.add_sequencing_groups_to_analysis( analysis_id=analysis_id, sequencing_group_ids=sequencing_group_ids @@ -577,16 +571,14 @@ async def update_analysis( status: AnalysisStatus, meta: dict[str, Any] = None, output: str | None = None, - check_project_id=True, ): """ Update the status of an analysis, set timestamp_completed if relevant """ - if check_project_id: - project_ids = await self.at.get_project_ids_for_analysis_ids([analysis_id]) - await self.ptable.check_access_to_project_ids( - self.author, project_ids, allowed_roles=FullWriteAccessRoles - ) + project_ids = await self.at.get_project_ids_for_analysis_ids([analysis_id]) + self.connection.check_access_to_projects_for_ids( + project_ids, allowed_roles=FullWriteAccessRoles + ) await self.at.update_analysis( analysis_id=analysis_id, @@ -597,17 +589,15 @@ async def update_analysis( async def get_analysis_runner_log( self, - project_ids: list[int] = None, - # author: str = None, - output_dir: str = None, - ar_guid: str = None, + project_ids: list[int], + output_dir: str | None = None, + ar_guid: str | None = None, ) -> list[AnalysisInternal]: """ Get log for the analysis-runner, useful for checking this history of analysis """ return await self.at.get_analysis_runner_log( project_ids, - # author=author, output_dir=output_dir, ar_guid=ar_guid, ) diff --git a/db/python/layers/analysis_runner.py b/db/python/layers/analysis_runner.py index 7cb504847..3422ee5c6 100644 --- a/db/python/layers/analysis_runner.py +++ b/db/python/layers/analysis_runner.py @@ -2,6 +2,7 @@ from db.python.layers.base import BaseLayer from db.python.tables.analysis_runner import AnalysisRunnerFilter, AnalysisRunnerTable from models.models.analysis_runner import AnalysisRunnerInternal +from models.models.project import ReadAccessRoles class AnalysisRunnerLayer(BaseLayer): @@ -15,19 +16,18 @@ def __init__(self, connection: Connection): # GETS async def query( - self, filter_: AnalysisRunnerFilter, check_project_ids: bool = True + self, filter_: AnalysisRunnerFilter ) -> list[AnalysisRunnerInternal]: """Get analysis runner logs""" logs = await self.at.query(filter_) if not logs: return [] - if check_project_ids: - project_ids = set(log.project for log in logs) + project_ids = set(log.project for log in logs) - await self.ptable.check_access_to_project_ids( - self.author, project_ids, readonly=True - ) + self.connection.check_access_to_projects_for_ids( + project_ids, allowed_roles=ReadAccessRoles + ) return logs diff --git a/db/python/layers/assay.py b/db/python/layers/assay.py index 9d6a04c9d..0acb49b29 100644 --- a/db/python/layers/assay.py +++ b/db/python/layers/assay.py @@ -6,7 +6,7 @@ from db.python.tables.sample import SampleTable from db.python.utils import NoOpAenter from models.models.assay import AssayInternal, AssayUpsertInternal -from models.models.group import FullWriteAccessRoles, ReadAccessRoles +from models.models.project import FullWriteAccessRoles, ReadAccessRoles class AssayLayer(BaseLayer): @@ -18,7 +18,7 @@ def __init__(self, connection: Connection): self.sampt: SampleTable = SampleTable(connection) # GET - async def query(self, filter_: AssayFilter = None, check_project_id=True): + async def query(self, filter_: AssayFilter = None): """Query for samples""" projects, assays = await 
self.seqt.query(filter_) @@ -26,23 +26,19 @@ async def query(self, filter_: AssayFilter = None, check_project_id=True): if not assays: return [] - if check_project_id: - await self.ptable.check_access_to_project_ids( - user=self.author, project_ids=projects, allowed_roles=ReadAccessRoles - ) + self.connection.check_access_to_projects_for_ids( + project_ids=projects, allowed_roles=ReadAccessRoles + ) return assays - async def get_assay_by_id( - self, assay_id: int, check_project_id=True - ) -> AssayInternal: + async def get_assay_by_id(self, assay_id: int) -> AssayInternal: """Get assay by ID""" project, assay = await self.seqt.get_assay_by_id(assay_id) - if check_project_id: - await self.ptable.check_access_to_project_id( - self.author, project, allowed_roles=ReadAccessRoles - ) + self.connection.check_access_to_projects_for_ids( + [project], allowed_roles=ReadAccessRoles + ) return assay @@ -60,7 +56,7 @@ async def get_assay_by_external_id( return assay async def get_assays_for_sequencing_group_ids( - self, sequencing_group_ids: list[int], check_project_ids: bool = True + self, sequencing_group_ids: list[int] ) -> dict[int, list[AssayInternal]]: """Get assays for a list of sequencing group IDs""" projects, assays = await self.seqt.get_assays_for_sequencing_group_ids( @@ -70,10 +66,9 @@ async def get_assays_for_sequencing_group_ids( if not assays: return {} - if check_project_ids: - await self.ptable.check_access_to_project_ids( - self.author, projects, allowed_roles=ReadAccessRoles - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) return assays @@ -105,19 +100,16 @@ async def get_assays_by( active=active, ) - if not project_ids: - # if we didn't specify a project, we need to check access - # to the projects we got back - await self.ptable.check_access_to_project_ids( - self.author, projs, allowed_roles=ReadAccessRoles - ) + self.connection.check_access_to_projects_for_ids( + projs, allowed_roles=ReadAccessRoles + ) return seqs # region UPSERTs async def upsert_assay( - self, assay: AssayUpsertInternal, check_project_id=True, open_transaction=True + self, assay: AssayUpsertInternal, open_transaction=True ) -> AssayUpsertInternal: """Upsert a single assay""" @@ -128,8 +120,9 @@ async def upsert_assay( project_ids = await self.sampt.get_project_ids_for_sample_ids( [assay.sample_id] ) - await self.ptable.check_access_to_project_ids( - self.author, project_ids, allowed_roles=FullWriteAccessRoles + + self.connection.check_access_to_projects_for_ids( + project_ids, allowed_roles=FullWriteAccessRoles ) seq_id = await self.seqt.insert_assay( @@ -141,12 +134,12 @@ async def upsert_assay( ) assay.id = seq_id else: - if check_project_id: - # can check the project id of the assay we're updating - project_ids = await self.seqt.get_projects_by_assay_ids([assay.id]) - await self.ptable.check_access_to_project_ids( - self.author, project_ids, allowed_roles=FullWriteAccessRoles - ) + # can check the project id of the assay we're updating + project_ids = await self.seqt.get_projects_by_assay_ids([assay.id]) + self.connection.check_access_to_projects_for_ids( + project_ids, allowed_roles=FullWriteAccessRoles + ) + # Otherwise update await self.seqt.update_assay( assay.id, @@ -161,27 +154,24 @@ async def upsert_assay( async def upsert_assays( self, assays: list[AssayUpsertInternal], - check_project_ids: bool = True, open_transaction=True, ) -> list[AssayUpsertInternal]: """Upsert multiple sequences to the given sample (sid)""" - if check_project_ids: - 
sample_ids = set(s.sample_id for s in assays) - st = SampleTable(self.connection) - project_ids = await st.get_project_ids_for_sample_ids(list(sample_ids)) - await self.ptable.check_access_to_project_ids( - self.author, project_ids, allowed_roles=FullWriteAccessRoles - ) + sample_ids = set(s.sample_id for s in assays) + st = SampleTable(self.connection) + project_ids = await st.get_project_ids_for_sample_ids(list(sample_ids)) + + self.connection.check_access_to_projects_for_ids( + project_ids, allowed_roles=FullWriteAccessRoles + ) with_function = ( self.connection.connection.transaction if open_transaction else NoOpAenter ) async with with_function(): for a in assays: - await self.upsert_assay( - a, check_project_id=False, open_transaction=False - ) + await self.upsert_assay(a, open_transaction=False) return assays diff --git a/db/python/layers/audit_log.py b/db/python/layers/audit_log.py index 406408b1c..b3f86ad81 100644 --- a/db/python/layers/audit_log.py +++ b/db/python/layers/audit_log.py @@ -1,7 +1,7 @@ from db.python.layers.base import BaseLayer, Connection from db.python.tables.audit_log import AuditLogTable from models.models.audit_log import AuditLogId, AuditLogInternal -from models.models.group import ReadAccessRoles +from models.models.project import ReadAccessRoles class AuditLogLayer(BaseLayer): @@ -12,19 +12,16 @@ def __init__(self, connection: Connection): self.alayer: AuditLogTable = AuditLogTable(connection) # GET - async def get_for_ids( - self, ids: list[AuditLogId], check_project_id: bool = True - ) -> list[AuditLogInternal]: + async def get_for_ids(self, ids: list[AuditLogId]) -> list[AuditLogInternal]: """Query for samples""" if not ids: return [] logs = await self.alayer.get_audit_logs_for_ids(ids) - if check_project_id: - projects = {log.auth_project for log in logs} - await self.ptable.check_access_to_project_ids( - user=self.author, project_ids=projects, allowed_roles=ReadAccessRoles - ) + projects = {log.auth_project for log in logs} + self.connection.check_access_to_projects_for_ids( + project_ids=projects, allowed_roles=ReadAccessRoles + ) return logs diff --git a/db/python/layers/base.py b/db/python/layers/base.py index 28c7575a7..edb893c52 100644 --- a/db/python/layers/base.py +++ b/db/python/layers/base.py @@ -1,4 +1,4 @@ -from db.python.tables.project import Connection, ProjectPermissionsTable +from db.python.connect import Connection class BaseLayer: @@ -6,7 +6,6 @@ class BaseLayer: def __init__(self, connection: Connection): self.connection = connection - self.ptable = ProjectPermissionsTable(connection) @property def author(self): diff --git a/db/python/layers/cohort.py b/db/python/layers/cohort.py index f66deddc9..09c24cc77 100644 --- a/db/python/layers/cohort.py +++ b/db/python/layers/cohort.py @@ -2,7 +2,6 @@ from db.python.layers.base import BaseLayer from db.python.layers.sequencing_group import SequencingGroupLayer from db.python.tables.cohort import CohortFilter, CohortTable, CohortTemplateFilter -from db.python.tables.project import ProjectId, ProjectPermissionsTable from db.python.tables.sample import SampleFilter, SampleTable from db.python.tables.sequencing_group import ( SequencingGroupFilter, @@ -15,6 +14,7 @@ CohortTemplateInternal, NewCohortInternal, ) +from models.models.project import ProjectId, ReadAccessRoles logger = get_logger() @@ -62,29 +62,23 @@ def __init__(self, connection: Connection): self.sampt = SampleTable(connection) self.ct = CohortTable(connection) - self.pt = ProjectPermissionsTable(connection) self.sgt = 
SequencingGroupTable(connection) self.sglayer = SequencingGroupLayer(self.connection) - async def query( - self, filter_: CohortFilter, check_project_ids: bool = True - ) -> list[CohortInternal]: + async def query(self, filter_: CohortFilter) -> list[CohortInternal]: """Query Cohorts""" cohorts, project_ids = await self.ct.query(filter_) if not cohorts: return [] - if check_project_ids: - await self.pt.get_and_check_access_to_projects_for_ids( - user=self.connection.author, - project_ids=list(project_ids), - readonly=True, - ) + self.connection.check_access_to_projects_for_ids( + project_ids=list(project_ids), allowed_roles=ReadAccessRoles + ) return cohorts async def query_cohort_templates( - self, filter_: CohortTemplateFilter, check_project_ids: bool = True + self, filter_: CohortTemplateFilter ) -> list[CohortTemplateInternal]: """Query CohortTemplates""" project_ids, cohort_templates = await self.ct.query_cohort_templates(filter_) @@ -92,12 +86,10 @@ async def query_cohort_templates( if not cohort_templates: return [] - if check_project_ids: - await self.pt.get_and_check_access_to_projects_for_ids( - user=self.connection.author, - project_ids=list(project_ids), - readonly=True, - ) + self.connection.check_access_to_projects_for_ids( + project_ids=list(project_ids), allowed_roles=ReadAccessRoles + ) + return cohort_templates async def get_template_by_cohort_id(self, cohort_id: int) -> CohortTemplateInternal: diff --git a/db/python/layers/family.py b/db/python/layers/family.py index 56fb12942..4d345e941 100644 --- a/db/python/layers/family.py +++ b/db/python/layers/family.py @@ -13,9 +13,8 @@ from db.python.tables.sample import SampleTable from db.python.utils import GenericFilter, NotFoundError from models.models.family import FamilyInternal, PedRow, PedRowInternal -from models.models.group import FullWriteAccessRoles, ReadAccessRoles from models.models.participant import ParticipantUpsertInternal -from models.models.project import ProjectId +from models.models.project import ProjectId, ReadAccessRoles class FamilyLayer(BaseLayer): @@ -37,9 +36,7 @@ async def create_family( coded_phenotype=coded_phenotype, ) - async def get_family_by_internal_id( - self, family_id: int, check_project_id: bool = True - ) -> FamilyInternal: + async def get_family_by_internal_id(self, family_id: int) -> FamilyInternal: """Get family by internal ID""" projects, families = await self.ftable.query( FamilyFilter(id=GenericFilter(eq=family_id)) @@ -47,10 +44,10 @@ async def get_family_by_internal_id( if not families: raise NotFoundError(f'Family with ID {family_id} not found') family = families[0] - if check_project_id: - await self.ptable.check_access_to_project_ids( - self.author, [project], allowed_roles=ReadAccessRoles - ) + + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) return family @@ -61,7 +58,7 @@ async def get_family_by_external_id( families = await self.ftable.query( FamilyFilter( external_id=GenericFilter(eq=external_id), - project=GenericFilter(eq=project or self.connection.project), + project=GenericFilter(eq=project or self.connection.project_id), ) ) if not families: @@ -72,7 +69,6 @@ async def get_family_by_external_id( async def query( self, filter_: FamilyFilter, - check_project_ids: bool = True, ) -> list[FamilyInternal]: """Get all families for a project""" @@ -80,10 +76,9 @@ async def query( projects, families = await self.ftable.query(filter_) - if check_project_ids: - await self.ptable.check_access_to_project_ids( - 
self.connection.author, projects, readonly=True - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) return families @@ -91,7 +86,6 @@ async def get_families_by_ids( self, family_ids: list[int], check_missing: bool = True, - check_project_ids: bool = True, ) -> list[FamilyInternal]: """Get families by internal IDs""" projects, families = await self.ftable.query( @@ -100,10 +94,9 @@ async def get_families_by_ids( if not families: return [] - if check_project_ids: - await self.ptable.check_access_to_project_ids( - self.connection.author, projects, allowed_roles=ReadAccessRoles - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) if check_missing and len(family_ids) != len(families): missing_ids = set(family_ids) - set(f.id for f in families) @@ -112,7 +105,7 @@ async def get_families_by_ids( return families async def get_families_by_participants( - self, participant_ids: list[int], check_project_ids: bool = True + self, participant_ids: list[int] ) -> dict[int, list[FamilyInternal]]: """ Get families keyed by participant_ids, this will duplicate families @@ -123,10 +116,9 @@ async def get_families_by_participants( if not participant_map: return {} - if check_project_ids: - await self.ptable.check_access_to_project_ids( - self.connection.author, projects, allowed_roles=ReadAccessRoles - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) return participant_map @@ -136,14 +128,13 @@ async def update_family( external_id: str = None, description: str = None, coded_phenotype: str = None, - check_project_ids: bool = True, ) -> bool: """Update fields on some family""" - if check_project_ids: - project_ids = await self.ftable.get_projects_by_family_ids([id_]) - await self.ptable.check_access_to_project_ids( - self.author, project_ids, allowed_roles=FullWriteAccessRoles - ) + project_ids = await self.ftable.get_projects_by_family_ids([id_]) + + self.connection.check_access_to_projects_for_ids( + project_ids, allowed_roles=ReadAccessRoles + ) return await self.ftable.update_family( id_=id_, @@ -214,9 +205,7 @@ async def get_pedigree( return mapped_rows - async def get_participant_family_map( - self, participant_ids: list[int], check_project_ids=False - ): + async def get_participant_family_map(self, participant_ids: list[int]): """Get participant family map""" fptable = FamilyParticipantTable(self.connection) @@ -224,8 +213,9 @@ async def get_participant_family_map( participant_ids=participant_ids ) - if check_project_ids: - raise NotImplementedError(f'Must check specified projects: {projects}') + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) return family_map @@ -277,7 +267,7 @@ async def import_pedigree( external_family_id_map = await self.ftable.get_id_map_by_external_ids( list(external_family_ids), - project=self.connection.project, + project=self.connection.project_id, allow_missing=True, ) missing_external_family_ids = [ @@ -285,7 +275,7 @@ async def import_pedigree( ] external_participant_ids_map = await participant_table.get_id_map_by_external_ids( list(external_participant_ids), - project=self.connection.project, + project=self.connection.project_id, # Allow missing participants if we're creating them allow_missing=create_missing_participants, ) @@ -418,7 +408,7 @@ def select_columns( return True async def get_family_participants_by_family_ids( - self, family_ids: list[int], check_project_ids: bool = 
True + self, family_ids: list[int] ) -> dict[int, list[PedRowInternal]]: """Get family participants for family IDs""" projects, fps = await self.fptable.query( @@ -428,15 +418,14 @@ async def get_family_participants_by_family_ids( if not fps: return {} - if check_project_ids: - await self.ptable.check_access_to_project_ids( - self.connection.author, projects, readonly=True - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) return group_by(fps, lambda r: r.family_id) async def get_family_participants_for_participants( - self, participant_ids: list[int], check_project_ids: bool = True + self, participant_ids: list[int] ) -> list[PedRowInternal]: """Get family participants for participant IDs""" projects, fps = await self.fptable.query( @@ -446,9 +435,8 @@ async def get_family_participants_for_participants( if not fps: return [] - if check_project_ids: - await self.ptable.check_access_to_project_ids( - self.connection.author, projects, readonly=True - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) return fps diff --git a/db/python/layers/participant.py b/db/python/layers/participant.py index a4f529723..1c05b0a71 100644 --- a/db/python/layers/participant.py +++ b/db/python/layers/participant.py @@ -4,6 +4,7 @@ from enum import Enum from typing import Any +from db.python.connect import Connection from db.python.layers.base import BaseLayer from db.python.layers.sample import SampleLayer from db.python.tables.family import FamilyTable @@ -21,9 +22,8 @@ split_generic_terms, ) from models.models.family import PedRowInternal -from models.models.group import FullWriteAccessRoles, ReadAccessRoles from models.models.participant import ParticipantInternal, ParticipantUpsertInternal -from models.models.project import ProjectId +from models.models.project import FullWriteAccessRoles, ProjectId, ReadAccessRoles HPO_REGEX_MATCHER = re.compile(r'HP\:\d+$') @@ -239,14 +239,13 @@ def process_hpo_term(term): class ParticipantLayer(BaseLayer): """Layer for more complex sample logic""" - def __init__(self, connection): + def __init__(self, connection: Connection): super().__init__(connection) self.pttable = ParticipantTable(connection=connection) async def get_participants_by_ids( self, pids: list[int], - check_project_ids: bool = True, allow_missing: bool = False, ) -> list[ParticipantInternal]: """ @@ -257,10 +256,9 @@ async def get_participants_by_ids( if not participants: return [] - if check_project_ids: - await self.ptable.check_access_to_project_ids( - self.author, projects, allowed_roles=ReadAccessRoles - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) if not allow_missing and len(participants) != len(pids): # participants are missing @@ -298,7 +296,7 @@ async def fill_in_missing_participants(self): # { external_id: internal_id } samples_with_no_pid = ( await sample_table.get_samples_with_missing_participants_by_internal_id( - project=self.connection.project + project=self.connection.project_id ) ) external_sample_map_with_no_pid = { @@ -308,7 +306,7 @@ async def fill_in_missing_participants(self): unlinked_participants = await self.get_id_map_by_external_ids( list(external_sample_map_with_no_pid.keys()), - project=self.connection.project, + project=self.connection.project_id, allow_missing=True, ) @@ -399,10 +397,10 @@ async def generic_individual_metadata_importer( allow_missing_participants = ( extra_participants_method != 
ExtraParticipantImporterHandler.FAIL ) - assert self.connection.project + assert self.connection.project_id external_pid_map = await self.get_id_map_by_external_ids( list(external_participant_ids), - project=self.connection.project, + project=self.connection.project_id, allow_missing=allow_missing_participants, ) if extra_participants_method == ExtraParticipantImporterHandler.ADD: @@ -416,7 +414,7 @@ async def generic_individual_metadata_importer( reported_gender=None, karyotype=None, meta=None, - project=self.connection.project, + project=self.connection.project_id, ) elif extra_participants_method == ExtraParticipantImporterHandler.IGNORE: rows = [ @@ -449,7 +447,7 @@ async def generic_individual_metadata_importer( fmap_by_internal = await ftable.get_id_map_by_internal_ids(list(fids)) fmap_from_external = await ftable.get_id_map_by_external_ids( list(external_family_ids), - project=self.connection.project, + project=self.connection.project_id, allow_missing=True, ) fmap_by_external = { @@ -522,7 +520,7 @@ async def generic_individual_metadata_importer( return True async def get_participants_by_families( - self, family_ids: list[int], check_project_ids: bool = True + self, family_ids: list[int] ) -> dict[int, list[ParticipantInternal]]: """Get participants, keyed by family ID""" projects, family_map = await self.pttable.get_participants_by_families( @@ -531,10 +529,9 @@ async def get_participants_by_families( if not family_map: return {} - if check_project_ids: - await self.ptable.check_access_to_project_ids( - self.connection.author, projects, allowed_roles=ReadAccessRoles - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) return family_map @@ -582,7 +579,6 @@ async def upsert_participant( self, participant: ParticipantUpsertInternal, project: ProjectId = None, - check_project_id: bool = True, open_transaction=True, ) -> ParticipantUpsertInternal: """Create a single participant""" @@ -594,18 +590,13 @@ async def upsert_participant( async with with_function(): if participant.id: - if check_project_id: - project_ids = ( - await self.pttable.get_project_ids_for_participant_ids( - [participant.id] - ) - ) + project_ids = await self.pttable.get_project_ids_for_participant_ids( + [participant.id] + ) - await self.ptable.check_access_to_project_ids( - self.connection.author, - project_ids, - allowed_roles=FullWriteAccessRoles, - ) + self.connection.check_access_to_projects_for_ids( + project_ids, allowed_roles=FullWriteAccessRoles + ) await self.pttable.update_participant( participant_id=participant.id, external_id=participant.external_id, @@ -633,7 +624,6 @@ async def upsert_participant( await slayer.upsert_samples( participant.samples, project=project, - check_project_id=False, open_transaction=False, ) @@ -659,18 +649,15 @@ async def upsert_participants( return participants async def update_many_participant_external_ids( - self, internal_to_external_id: dict[int, str], check_project_ids=True + self, internal_to_external_id: dict[int, str] ): """Update many participant external ids""" - if check_project_ids: - projects = await self.pttable.get_project_ids_for_participant_ids( - list(internal_to_external_id.keys()) - ) - await self.ptable.check_access_to_project_ids( - user=self.author, - project_ids=projects, - allowed_roles=FullWriteAccessRoles, - ) + projects = await self.pttable.get_project_ids_for_participant_ids( + list(internal_to_external_id.keys()) + ) + self.connection.check_access_to_projects_for_ids( + projects, 
allowed_roles=FullWriteAccessRoles + ) return await self.pttable.update_many_participant_external_ids( internal_to_external_id @@ -710,12 +697,12 @@ async def get_seqr_individual_template( internal_to_external_fid_map = {} if external_participant_ids or internal_participant_ids: - assert self.connection.project + assert self.connection.project_id pids = set(internal_participant_ids or []) if external_participant_ids: pid_map = await self.get_id_map_by_external_ids( external_participant_ids, - project=self.connection.project, + project=self.connection.project_id, allow_missing=False, ) pids |= set(pid_map.values()) @@ -783,7 +770,7 @@ async def get_seqr_individual_template( } async def get_family_participant_data( - self, family_id: int, participant_id: int, check_project_ids: bool = True + self, family_id: int, participant_id: int ) -> PedRowInternal: """Gets the family_participant row for a specific participant""" fptable = FamilyParticipantTable(self.connection) @@ -799,10 +786,10 @@ async def get_family_participant_data( f'Family participant row (family_id: {family_id}, ' f'participant_id: {participant_id}) not found' ) - if check_project_ids: - await self.ptable.check_access_to_project_ids( - self.author, projects, readonly=True - ) + + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) return rows[0] @@ -952,10 +939,9 @@ async def check_project_access_for_participants_families( ) ftable = FamilyTable(self.connection) fprojects = await ftable.get_projects_by_family_ids(family_ids=family_ids) - return await self.ptable.check_access_to_project_ids( - self.connection.author, - list(pprojects | fprojects), - allowed_roles=ReadAccessRoles, + + return self.connection.check_access_to_projects_for_ids( + list(pprojects | fprojects), allowed_roles=ReadAccessRoles ) async def update_participant_family( diff --git a/db/python/layers/sample.py b/db/python/layers/sample.py index 7d805284c..92c61eae0 100644 --- a/db/python/layers/sample.py +++ b/db/python/layers/sample.py @@ -6,11 +6,9 @@ from db.python.layers.base import BaseLayer, Connection from db.python.layers.sequencing_group import SequencingGroupLayer from db.python.tables.assay import NoOpAenter -from db.python.tables.project import ProjectPermissionsTable from db.python.tables.sample import SampleFilter, SampleTable from db.python.utils import GenericFilter, NotFoundError -from models.models.group import FullWriteAccessRoles, ReadAccessRoles -from models.models.project import ProjectId +from models.models.project import FullWriteAccessRoles, ProjectId, ReadAccessRoles from models.models.sample import SampleInternal, SampleUpsertInternal from models.utils.sample_id_format import sample_id_format_list @@ -21,37 +19,33 @@ class SampleLayer(BaseLayer): def __init__(self, connection: Connection): super().__init__(connection) self.st: SampleTable = SampleTable(connection) - self.pt = ProjectPermissionsTable(connection) self.connection = connection # GETS - async def get_by_id(self, sample_id: int, check_project_id=True) -> SampleInternal: + async def get_by_id(self, sample_id: int) -> SampleInternal: """Get sample by internal sample id""" project, sample = await self.st.get_sample_by_id(sample_id) - if check_project_id: - await self.pt.check_access_to_project_ids( - self.connection.author, [project], allowed_roles=ReadAccessRoles - ) + + self.connection.check_access_to_projects_for_ids( + [project], allowed_roles=ReadAccessRoles + ) return sample - async def query( - self, filter_: SampleFilter, 
check_project_ids: bool = True - ) -> list[SampleInternal]: + async def query(self, filter_: SampleFilter) -> list[SampleInternal]: """Query samples""" projects, samples = await self.st.query(filter_) if not samples: return samples - if check_project_ids: - await self.pt.check_access_to_project_ids( - self.connection.author, projects, allowed_roles=ReadAccessRoles - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) return samples async def get_samples_by_participants( - self, participant_ids: list[int], check_project_ids: bool = True + self, participant_ids: list[int] ) -> dict[int, list[SampleInternal]]: """Get map of samples by participants""" @@ -64,10 +58,9 @@ async def get_samples_by_participants( if not samples: return {} - if check_project_ids: - await self.ptable.check_access_to_project_ids( - self.author, projects, allowed_roles=ReadAccessRoles - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) grouped_samples = group_by(samples, lambda s: s.participant_id) @@ -77,15 +70,12 @@ async def get_project_ids_for_sample_ids(self, sample_ids: list[int]) -> set[int """Return the projects associated with the sample ids""" return await self.st.get_project_ids_for_sample_ids(sample_ids) - async def get_sample_by_id( - self, sample_id: int, check_project_id=True - ) -> SampleInternal: + async def get_sample_by_id(self, sample_id: int) -> SampleInternal: """Get sample by ID""" project, sample = await self.st.get_sample_by_id(sample_id) - if check_project_id: - await self.pt.check_access_to_project_ids( - self.author, [project], allowed_roles=ReadAccessRoles - ) + self.connection.check_access_to_projects_for_ids( + [project], allowed_roles=ReadAccessRoles + ) return sample @@ -105,7 +95,7 @@ async def get_sample_id_map_by_external_ids( ) -> dict[str, int]: """Get map of samples {external_id: internal_id}""" external_ids_set = set(external_ids) - _project = project or self.connection.project + _project = project or self.connection.project_id assert _project sample_id_map = await self.st.get_sample_id_map_by_external_ids( external_ids=list(external_ids_set), project=_project @@ -122,7 +112,7 @@ async def get_sample_id_map_by_external_ids( ) async def get_internal_to_external_sample_id_map( - self, sample_ids: list[int], check_project_ids=True, allow_missing=False + self, sample_ids: list[int], allow_missing=False ) -> dict[int, str]: """Get map of internal sample id to external id""" @@ -147,10 +137,9 @@ async def get_internal_to_external_sample_id_map( if not sample_id_map: return {} - if check_project_ids: - await self.ptable.check_access_to_project_ids( - self.author, projects, allowed_roles=ReadAccessRoles - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) return sample_id_map @@ -167,18 +156,17 @@ async def get_samples_by( participant_ids: list[int] | None = None, project_ids=None, active=True, - check_project_ids=True, ) -> list[SampleInternal]: """Get samples by some criteria""" if not sample_ids and not project_ids: raise ValueError('Must specify one of "project_ids" or "sample_ids"') - if sample_ids and check_project_ids: + if sample_ids: # project_ids were already checked when transformed to ints, # so no else required pjcts = await self.st.get_project_ids_for_sample_ids(sample_ids) - await self.ptable.check_access_to_project_ids( - self.author, pjcts, allowed_roles=ReadAccessRoles + self.connection.check_access_to_projects_for_ids( + 
pjcts, allowed_roles=ReadAccessRoles ) _returned_project_ids, samples = await self.st.query( @@ -193,10 +181,9 @@ async def get_samples_by( if not samples: return [] - if not project_ids and check_project_ids: - await self.ptable.check_access_to_project_ids( - self.author, _returned_project_ids, allowed_roles=ReadAccessRoles - ) + self.connection.check_access_to_projects_for_ids( + _returned_project_ids, allowed_roles=ReadAccessRoles + ) return samples @@ -212,8 +199,8 @@ async def get_samples_create_date( ) -> dict[int, datetime.date]: """Get a map of {internal_sample_id: date_created} for list of sample_ids""" pjcts = await self.st.get_project_ids_for_sample_ids(sample_ids) - await self.pt.check_access_to_project_ids( - self.author, pjcts, allowed_roles=ReadAccessRoles + self.connection.check_access_to_projects_for_ids( + pjcts, allowed_roles=ReadAccessRoles ) return await self.st.get_samples_create_date(sample_ids) @@ -276,7 +263,6 @@ async def upsert_samples( samples: list[SampleUpsertInternal], open_transaction: bool = True, project: ProjectId = None, - check_project_id=True, ) -> list[SampleUpsertInternal]: """Batch upsert a list of samples with sequences""" seqglayer: SequencingGroupLayer = SequencingGroupLayer(self.connection) @@ -285,13 +271,12 @@ async def upsert_samples( self.connection.connection.transaction if open_transaction else NoOpAenter ) - if check_project_id: - sids = [s.id for s in samples if s.id] - if sids: - pjcts = await self.st.get_project_ids_for_sample_ids(sids) - await self.ptable.check_access_to_project_ids( - self.author, pjcts, allowed_roles=FullWriteAccessRoles - ) + sids = [s.id for s in samples if s.id] + if sids: + pjcts = await self.st.get_project_ids_for_sample_ids(sids) + self.connection.check_access_to_projects_for_ids( + pjcts, allowed_roles=ReadAccessRoles + ) async with with_function(): # Create or update samples @@ -326,16 +311,12 @@ async def merge_samples( self, id_keep: int, id_merge: int, - check_project_id=True, ): """Merge two samples into one another""" - if check_project_id: - projects = await self.st.get_project_ids_for_sample_ids([id_keep, id_merge]) - await self.ptable.check_access_to_project_ids( - user=self.author, - project_ids=projects, - allowed_roles=FullWriteAccessRoles, - ) + projects = await self.st.get_project_ids_for_sample_ids([id_keep, id_merge]) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=FullWriteAccessRoles + ) return await self.st.merge_samples( id_keep=id_keep, @@ -343,7 +324,7 @@ async def merge_samples( ) async def update_many_participant_ids( - self, ids: list[int], participant_ids: list[int], check_sample_ids=True + self, ids: list[int], participant_ids: list[int] ) -> bool: """ Update participant IDs for many samples @@ -353,27 +334,24 @@ async def update_many_participant_ids( raise ValueError( f'Number of sampleIDs ({len(ids)}) and ParticipantIds ({len(participant_ids)}) did not match' ) - if check_sample_ids: - project_ids = await self.st.get_project_ids_for_sample_ids(ids) - await self.ptable.check_access_to_project_ids( - self.author, project_ids, allowed_roles=FullWriteAccessRoles - ) + + projects = await self.st.get_project_ids_for_sample_ids(ids) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=FullWriteAccessRoles + ) await self.st.update_many_participant_ids( ids=ids, participant_ids=participant_ids ) return True - async def get_history_of_sample( - self, id_: int, check_sample_ids: bool = True - ) -> list[SampleInternal]: + async def 
get_history_of_sample(self, id_: int) -> list[SampleInternal]: """Get the full history of a sample""" rows = await self.st.get_history_of_sample(id_) - if check_sample_ids: - project_ids = set(r.project for r in rows) - await self.ptable.check_access_to_project_ids( - self.author, project_ids, allowed_roles=ReadAccessRoles - ) + projects = set(r.project for r in rows) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=FullWriteAccessRoles + ) return rows diff --git a/db/python/layers/search.py b/db/python/layers/search.py index 98187250d..6223a236a 100644 --- a/db/python/layers/search.py +++ b/db/python/layers/search.py @@ -4,7 +4,6 @@ from db.python.layers.base import BaseLayer, Connection from db.python.tables.family import FamilyTable from db.python.tables.participant import ParticipantTable -from db.python.tables.project import ProjectPermissionsTable from db.python.tables.sample import SampleTable from db.python.tables.sequencing_group import SequencingGroupTable from db.python.utils import NotFoundError @@ -28,7 +27,6 @@ class SearchLayer(BaseLayer): def __init__(self, connection: Connection): super().__init__(connection) - self.pt = ProjectPermissionsTable(connection) self.connection = connection @staticmethod @@ -190,12 +188,12 @@ async def search(self, query: str, project_ids: list[int]) -> List[SearchRespons data=SampleSearchResponseData( project=project, id=sample_id_format(s_id), - family_external_ids=participant_family_eids.get(p_id) or [] - if p_id - else [], - participant_external_ids=sample_participant_eids.get(p_id) or [] - if p_id - else [], + family_external_ids=( + participant_family_eids.get(p_id) or [] if p_id else [] + ), + participant_external_ids=( + sample_participant_eids.get(p_id) or [] if p_id else [] + ), sample_external_ids=[s_eid], ), ) diff --git a/db/python/layers/seqr.py b/db/python/layers/seqr.py index f34cff4eb..65fa6adde 100644 --- a/db/python/layers/seqr.py +++ b/db/python/layers/seqr.py @@ -38,7 +38,6 @@ # literally the most temporary thing ever, but for complete # automation need to have sample inclusion / exclusion -from models.models.group import ReadAccessRoles from models.utils.sequencing_group_id_format import ( sequencing_group_id_format, sequencing_group_id_format_list, @@ -135,14 +134,13 @@ async def sync_dataset( raise ValueError('Seqr synchronisation is not configured in metamist') token = self.generate_seqr_auth_token() - project = await self.ptable.get_and_check_access_to_project_for_id( - self.connection.author, - project_id=self.connection.project, - allowed_roles=ReadAccessRoles, - ) + project = self.connection.project + assert project - seqr_guid = project.meta.get( - self.get_meta_key_from_sequencing_type(sequencing_type) + seqr_guid = ( + project.meta.get(self.get_meta_key_from_sequencing_type(sequencing_type)) + if project.meta + else None ) if not seqr_guid: @@ -369,7 +367,9 @@ async def sync_individual_metadata( f'Uploaded individual metadata for {len(processed_records)} individuals' ] - def check_updated_sequencing_group_ids(self, sequencing_group_ids: set[int], es_index_analyses: list[AnalysisInternal]): + def check_updated_sequencing_group_ids( + self, sequencing_group_ids: set[int], es_index_analyses: list[AnalysisInternal] + ): """Check if the sequencing group IDs have been updated""" messages = [] if sequencing_group_ids: @@ -389,7 +389,8 @@ def check_updated_sequencing_group_ids(self, sequencing_group_ids: set[int], es_ ) if sequencing_groups_diff: messages.append( - f'Sequencing groups added to 
{es_index_analyses[-1].output}: ' + ', '.join(sequencing_groups_diff), + f'Sequencing groups added to {es_index_analyses[-1].output}: ' + + ', '.join(sequencing_groups_diff), ) sg_ids_missing_from_index = sequencing_group_id_format_list( @@ -402,7 +403,13 @@ def check_updated_sequencing_group_ids(self, sequencing_group_ids: set[int], es_ ) return messages - async def post_es_index_update(self, session: aiohttp.ClientSession, url: str, post_json: dict, headers: dict[str, str]): + async def post_es_index_update( + self, + session: aiohttp.ClientSession, + url: str, + post_json: dict, + headers: dict[str, str], + ): """Post request to update ES index""" resp = await session.post( url=url, @@ -422,8 +429,9 @@ async def update_es_index( sequencing_group_ids: set[int], ) -> list[str]: """Update seqr samples for latest elastic-search index""" + assert self.connection.project_id eid_to_sgid_rows = await self.player.get_external_participant_id_to_internal_sequencing_group_id_map( - self.connection.project, sequencing_type=sequencing_type + self.connection.project_id, sequencing_type=sequencing_type ) # format sample ID for transport @@ -456,7 +464,7 @@ async def update_es_index( alayer = AnalysisLayer(connection=self.connection) es_index_analyses = await alayer.query( AnalysisFilter( - project=GenericFilter(eq=self.connection.project), + project=GenericFilter(eq=self.connection.project_id), type=GenericFilter(eq='es-index'), status=GenericFilter(eq=AnalysisStatus.COMPLETED), meta={ @@ -486,7 +494,11 @@ async def update_es_index( es_index = es_indexes_filtered_by_type[-1].output - messages.extend(self.check_updated_sequencing_group_ids(sequencing_group_ids, es_indexes_filtered_by_type)) + messages.extend( + self.check_updated_sequencing_group_ids( + sequencing_group_ids, es_indexes_filtered_by_type + ) + ) req1_url = SEQR_URL + _url_update_es_index.format(projectGuid=project_guid) post_json = { @@ -495,7 +507,9 @@ async def update_es_index( 'mappingFilePath': fn_path, 'ignoreExtraSamplesInCallset': True, } - requests.append(self.post_es_index_update(session, req1_url, post_json, headers)) + requests.append( + self.post_es_index_update(session, req1_url, post_json, headers) + ) messages.extend(await asyncio.gather(*requests)) return messages @@ -528,8 +542,9 @@ async def sync_cram_map( alayer = AnalysisLayer(self.connection) + assert self.connection.project_id reads_map = await alayer.get_sample_cram_path_map_for_seqr( - project=self.connection.project, + project=self.connection.project_id, sequencing_types=[sequencing_type], participant_ids=participant_ids, ) @@ -605,9 +620,9 @@ async def _make_update_igv_call(update): async def _get_pedigree_from_sm(self, family_ids: set[int]) -> list[dict] | None: """Call get_pedigree and return formatted string with header""" - + assert self.connection.project_id ped_rows = await self.flayer.get_pedigree( - self.connection.project, + self.connection.project_id, family_ids=list(family_ids), replace_with_family_external_ids=True, replace_with_participant_external_ids=True, @@ -664,8 +679,9 @@ async def get_individual_meta_objs_for_seqr( self, participant_ids: list[int] ) -> list[dict] | None: """Get formatted list of dictionaries for syncing individual meta to seqr""" + assert self.connection.project_id individual_metadata_resp = await self.player.get_seqr_individual_template( - self.connection.project, internal_participant_ids=participant_ids + self.connection.project_id, internal_participant_ids=participant_ids ) json_rows: list[dict] = 
individual_metadata_resp['rows'] diff --git a/db/python/layers/sequencing_group.py b/db/python/layers/sequencing_group.py index c88ad066a..1d91621de 100644 --- a/db/python/layers/sequencing_group.py +++ b/db/python/layers/sequencing_group.py @@ -10,8 +10,7 @@ SequencingGroupTable, ) from db.python.utils import NotFoundError -from models.models.group import ReadAccessRoles -from models.models.project import ProjectId +from models.models.project import ProjectId, ReadAccessRoles from models.models.sequencing_group import ( SequencingGroupInternal, SequencingGroupInternalId, @@ -29,19 +28,17 @@ def __init__(self, connection: Connection): self.sampt: SampleTable = SampleTable(connection) async def get_sequencing_group_by_id( - self, sequencing_group_id: int, check_project_id: bool = True + self, sequencing_group_id: int ) -> SequencingGroupInternal: """ Get sequencing group by internal ID """ - groups = await self.get_sequencing_groups_by_ids( - [sequencing_group_id], check_project_ids=check_project_id - ) + groups = await self.get_sequencing_groups_by_ids([sequencing_group_id]) return groups[0] async def get_sequencing_groups_by_ids( - self, sequencing_group_ids: list[int], check_project_ids: bool = True + self, sequencing_group_ids: list[int] ) -> list[SequencingGroupInternal]: """ Get sequence groups by internal IDs @@ -53,10 +50,9 @@ async def get_sequencing_groups_by_ids( sequencing_group_ids ) - if check_project_ids: - await self.ptable.check_access_to_project_ids( - self.author, projects, allowed_roles=ReadAccessRoles - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) if len(groups) != len(sequencing_group_ids): missing_ids = set(sequencing_group_ids) - set(sg.id for sg in groups) @@ -68,7 +64,7 @@ async def get_sequencing_groups_by_ids( return groups async def get_sequencing_groups_by_analysis_ids( - self, analysis_ids: list[int], check_project_ids: bool = True + self, analysis_ids: list[int] ) -> dict[int, list[SequencingGroupInternal]]: """ Get sequencing groups by analysis IDs @@ -83,17 +79,15 @@ async def get_sequencing_groups_by_analysis_ids( if not groups: return groups - if check_project_ids: - await self.ptable.check_access_to_project_ids( - self.author, projects, allowed_roles=ReadAccessRoles - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) return groups async def query( self, filter_: SequencingGroupFilter, - check_project_ids: bool = True, ) -> list[SequencingGroupInternal]: """ Query sequencing groups @@ -102,10 +96,9 @@ async def query( if not sequencing_groups: return [] - if check_project_ids and not (filter_.project and filter_.project.in_): - await self.ptable.check_access_to_project_ids( - self.author, projects, allowed_roles=ReadAccessRoles - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) return sequencing_groups @@ -138,7 +131,7 @@ async def get_all_sequencing_group_ids_by_sample_ids_by_type( return await self.seqgt.get_all_sequencing_group_ids_by_sample_ids_by_type() async def get_participant_ids_sequencing_group_ids_for_sequencing_type( - self, sequencing_type: str, check_project_ids: bool = True + self, sequencing_type: str ) -> dict[int, list[int]]: """ Get list of partiicpant IDs for a specific sequence type, @@ -153,10 +146,9 @@ async def get_participant_ids_sequencing_group_ids_for_sequencing_type( if not pids: return {} - if check_project_ids: - await self.ptable.check_access_to_project_ids( - 
self.author, projects, allowed_roles=ReadAccessRoles - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) return pids diff --git a/db/python/layers/web.py b/db/python/layers/web.py index 981edfebb..2f4083764 100644 --- a/db/python/layers/web.py +++ b/db/python/layers/web.py @@ -13,7 +13,6 @@ from db.python.tables.analysis import AnalysisTable from db.python.tables.assay import AssayTable from db.python.tables.base import DbBase -from db.python.tables.project import ProjectPermissionsTable from db.python.tables.sequencing_group import SequencingGroupTable from db.python.utils import escape_like_term from models.models import ( @@ -25,7 +24,6 @@ SearchItem, parse_sql_bool, ) -from models.models.group import ReadAccessRoles from models.models.web import ProjectSummaryInternal, WebProject @@ -56,7 +54,7 @@ def _project_summary_sample_query(self, grid_filter: list[SearchItem]): Get query for getting list of samples """ wheres = ['s.project = :project', 's.active'] - values = {'project': self.project} + values = {'project': self.project_id} where_str = '' for query in grid_filter: value = query.query @@ -201,12 +199,12 @@ def _project_summary_process_sample_rows( async def get_total_number_of_samples(self): """Get total number of active samples within a project""" _query = 'SELECT COUNT(*) FROM sample WHERE project = :project AND active' - return await self.connection.fetch_val(_query, {'project': self.project}) + return await self.connection.fetch_val(_query, {'project': self.project_id}) async def get_total_number_of_participants(self): """Get total number of participants within a project""" _query = 'SELECT COUNT(*) FROM participant WHERE project = :project' - return await self.connection.fetch_val(_query, {'project': self.project}) + return await self.connection.fetch_val(_query, {'project': self.project_id}) async def get_total_number_of_sequencing_groups(self): """Get total number of sequencing groups within a project""" @@ -215,7 +213,7 @@ async def get_total_number_of_sequencing_groups(self): FROM sequencing_group sg INNER JOIN sample s ON s.id = sg.sample_id WHERE project = :project AND NOT sg.archived""" - return await self.connection.fetch_val(_query, {'project': self.project}) + return await self.connection.fetch_val(_query, {'project': self.project_id}) async def get_total_number_of_assays(self): """Get total number of sequences within a project""" @@ -224,7 +222,7 @@ async def get_total_number_of_assays(self): FROM assay sq INNER JOIN sample s ON s.id = sq.sample_id WHERE s.project = :project""" - return await self.connection.fetch_val(_query, {'project': self.project}) + return await self.connection.fetch_val(_query, {'project': self.project_id}) @staticmethod def _project_summary_process_family_rows_by_pid( @@ -281,10 +279,10 @@ async def get_project_summary( # do initial query to get sample info sampl = SampleLayer(self._connection) sample_query, values = self._project_summary_sample_query(grid_filter) - ptable = ProjectPermissionsTable(self._connection) - project_db = await ptable.get_and_check_access_to_project_for_id( - self.author, self.project, allowed_roles=ReadAccessRoles - ) + + project_db = self.project + assert project_db + project = WebProject( id=project_db.id, name=project_db.name, @@ -384,10 +382,12 @@ async def get_project_summary( self.get_total_number_of_participants(), self.get_total_number_of_sequencing_groups(), self.get_total_number_of_assays(), - 
atable.get_number_of_crams_by_sequencing_type(project=self.project), - sgtable.get_type_numbers_for_project(project=self.project), - seqtable.get_assay_type_numbers_by_batch_for_project(project=self.project), - atable.get_seqr_stats_by_sequencing_type(project=self.project), + atable.get_number_of_crams_by_sequencing_type(project=self.project_id), + sgtable.get_type_numbers_for_project(project=self.project_id), + seqtable.get_assay_type_numbers_by_batch_for_project( + project=self.project_id + ), + atable.get_seqr_stats_by_sequencing_type(project=self.project_id), SeqrLayer(self._connection).get_synchronisable_types(project_db), ) @@ -435,7 +435,6 @@ async def get_project_summary( reported_sex=None, reported_gender=None, karyotype=None, - # project=self.project, ) ) elif pid not in pid_seen: @@ -451,7 +450,6 @@ async def get_project_summary( reported_sex=p['reported_sex'], reported_gender=p['reported_gender'], karyotype=p['karyotype'], - # project=self.project, ) ) diff --git a/db/python/tables/analysis.py b/db/python/tables/analysis.py index 8cb6c4512..4b7d0abe6 100644 --- a/db/python/tables/analysis.py +++ b/db/python/tables/analysis.py @@ -76,7 +76,7 @@ async def create_analysis( ('meta', to_db_json(meta or {})), ('output', output), ('audit_log_id', await self.audit_log_id()), - ('project', project or self.project), + ('project', project or self.project_id), ('active', active if active is not None else True), ] @@ -337,7 +337,7 @@ async def get_all_sequencing_group_ids_without_analysis_type( rows = await self.connection.fetch_all( _query, - {'analysis_type': analysis_type, 'project': project or self.project}, + {'analysis_type': analysis_type, 'project': project or self.project_id}, ) return [row[0] for row in rows] @@ -452,10 +452,10 @@ async def get_sample_cram_path_map_for_seqr( async def get_analysis_runner_log( self, - project_ids: List[int] = None, + project_ids: List[int], # author: str = None, - output_dir: str = None, - ar_guid: str = None, + output_dir: str | None = None, + ar_guid: str | None = None, ) -> List[AnalysisInternal]: """ Get log for the analysis-runner, useful for checking this history of analysis @@ -465,9 +465,9 @@ async def get_analysis_runner_log( "type = 'analysis-runner'", 'active', ] - if project_ids: - wheres.append('project in :project_ids') - values['project_ids'] = project_ids + + wheres.append('project in :project_ids') + values['project_ids'] = project_ids if output_dir: wheres.append('(output = :output OR output LIKE :output_like)') diff --git a/db/python/tables/analysis_runner.py b/db/python/tables/analysis_runner.py index 697034d93..dabe294fd 100644 --- a/db/python/tables/analysis_runner.py +++ b/db/python/tables/analysis_runner.py @@ -83,7 +83,7 @@ async def insert_analysis_runner_entry( 'meta': to_db_json(analysis_runner.meta), 'output_path': analysis_runner.output_path, 'audit_log_id': await self.audit_log_id(), - 'project': self.project, + 'project': self.project_id, } await self.connection.execute(_query, values) diff --git a/db/python/tables/assay.py b/db/python/tables/assay.py index 3416f7451..8a9bc4bb3 100644 --- a/db/python/tables/assay.py +++ b/db/python/tables/assay.py @@ -142,12 +142,12 @@ async def get_assay_by_external_id( self, external_sequence_id: str, project: ProjectId | None = None ) -> AssayInternal: """Get assay by EXTERNAL ID""" - if not (project or self.project): + if not (project or self.project_id): raise ValueError('Getting assay by external ID requires a project') f = AssayFilter( 
external_id=GenericFilter(eq=external_sequence_id), - project=GenericFilter(eq=project or self.project), + project=GenericFilter(eq=project or self.project_id), ) _, assays = await self.query(f) @@ -264,7 +264,7 @@ async def insert_assay( ) if external_ids: - _project = project or self.project + _project = project or self.project_id if not _project: raise ValueError( 'When inserting an external identifier for a sequence, a ' @@ -279,7 +279,7 @@ async def insert_assay( audit_log_id = await self.audit_log_id() eid_values = [ { - 'project': project or self.project, + 'project': project or self.project_id, 'assay_id': id_of_new_assay, 'external_id': eid, 'name': name.lower(), @@ -358,7 +358,7 @@ async def update_assay( await self.connection.execute(_query, fields) if external_ids: - _project = project or self.project + _project = project or self.project_id if not _project: raise ValueError( 'When inserting or updating an external identifier for an ' diff --git a/db/python/tables/base.py b/db/python/tables/base.py index c069bf60a..2da7b5e12 100644 --- a/db/python/tables/base.py +++ b/db/python/tables/base.py @@ -25,6 +25,7 @@ def __init__(self, connection: Connection): self.connection: databases.Database = connection.connection self.author = connection.author self.project = connection.project + self.project_id = connection.project_id if self.author is None: raise InternalError(f'Must provide author to {self.__class__.__name__}') diff --git a/db/python/tables/cohort.py b/db/python/tables/cohort.py index c0e0da24c..f9cf3a5ba 100644 --- a/db/python/tables/cohort.py +++ b/db/python/tables/cohort.py @@ -3,7 +3,6 @@ import datetime from db.python.tables.base import DbBase -from db.python.tables.project import ProjectId from db.python.utils import GenericFilter, GenericFilterModel, NotFoundError, to_db_json from models.models.cohort import ( CohortCriteriaInternal, @@ -11,6 +10,7 @@ CohortTemplateInternal, NewCohortInternal, ) +from models.models.project import ProjectId @dataclasses.dataclass(kw_only=True) diff --git a/db/python/tables/family.py b/db/python/tables/family.py index 4f8aad623..a0de569e5 100644 --- a/db/python/tables/family.py +++ b/db/python/tables/family.py @@ -208,7 +208,7 @@ async def create_family( 'description': description, 'coded_phenotype': coded_phenotype, 'audit_log_id': await self.audit_log_id(), - 'project': project or self.project, + 'project': project or self.project_id, } keys = list(updater.keys()) str_keys = ', '.join(keys) @@ -237,7 +237,7 @@ async def insert_or_update_multiple_families( 'description': descr, 'coded_phenotype': cph, 'audit_log_id': await self.audit_log_id(), - 'project': project or self.project, + 'project': project or self.project_id, } for eid, descr, cph in zip(external_ids, descriptions, coded_phenotypes) ] @@ -271,7 +271,7 @@ async def get_id_map_by_external_ids( _query = 'SELECT external_id, id FROM family WHERE external_id in :external_ids AND project = :project' results = await self.connection.fetch_all( - _query, {'external_ids': family_ids, 'project': project or self.project} + _query, {'external_ids': family_ids, 'project': project or self.project_id} ) id_map = {r['external_id']: r['id'] for r in results} diff --git a/db/python/tables/participant.py b/db/python/tables/participant.py index 60d6c7111..a20f80a91 100644 --- a/db/python/tables/participant.py +++ b/db/python/tables/participant.py @@ -82,7 +82,7 @@ async def create_participant( """ Create a new sample, and add it to database """ - if not (project or self.project): + if not 
(project or self.project_id): raise ValueError('Must provide project to create participant') _query = """ @@ -102,7 +102,7 @@ async def create_participant( 'karyotype': karyotype, 'meta': to_db_json(meta or {}), 'audit_log_id': await self.audit_log_id(), - 'project': project or self.project, + 'project': project or self.project_id, }, ) @@ -194,7 +194,7 @@ async def get_id_map_by_external_ids( project: ProjectId | None, ) -> dict[str, int]: """Get map of {external_id: internal_participant_id}""" - _project = project or self.project + _project = project or self.project_id if not _project: raise ValueError( 'Must provide project to get participant id map by external' diff --git a/db/python/tables/sample.py b/db/python/tables/sample.py index 10639d914..52323ea67 100644 --- a/db/python/tables/sample.py +++ b/db/python/tables/sample.py @@ -154,7 +154,7 @@ async def insert_sample( ('type', sample_type), ('active', active), ('audit_log_id', await self.audit_log_id()), - ('project', project or self.project), + ('project', project or self.project_id), ] keys = [k for k, _ in kv_pairs] @@ -381,7 +381,8 @@ async def get_sample_id_map_by_external_ids( WHERE external_id in :external_ids AND project = :project """ rows = await self.connection.fetch_all( - _query, {'external_ids': external_ids, 'project': project or self.project} + _query, + {'external_ids': external_ids, 'project': project or self.project_id}, ) sample_id_map = {el[1]: el[0] for el in rows} @@ -406,7 +407,7 @@ async def get_all_sample_id_map_by_internal_ids( """Get sample id map for all samples""" _query = 'SELECT id, external_id FROM sample WHERE project = :project' rows = await self.connection.fetch_all( - _query, {'project': project or self.project} + _query, {'project': project or self.project_id} ) return {el[0]: el[1] for el in rows} @@ -468,6 +469,6 @@ async def get_samples_with_missing_participants_by_internal_id( WHERE participant_id IS NULL AND project = :project """ rows = await self.connection.fetch_all( - _query, {'project': project or self.project} + _query, {'project': project or self.project_id} ) return [SampleInternal.from_db(dict(d)) for d in rows] diff --git a/db/python/tables/sequencing_group.py b/db/python/tables/sequencing_group.py index 3821148a5..6aacc14a2 100644 --- a/db/python/tables/sequencing_group.py +++ b/db/python/tables/sequencing_group.py @@ -265,7 +265,7 @@ async def get_all_sequencing_group_ids_by_sample_ids_by_type( INNER JOIN sequencing_group sg ON s.id = sg.sample_id WHERE project = :project """ - rows = await self.connection.fetch_all(_query, {'project': self.project}) + rows = await self.connection.fetch_all(_query, {'project': self.project_id}) sequencing_group_ids_by_sample_ids_by_type: dict[int, dict[str, list[int]]] = ( defaultdict(lambda: defaultdict(list)) ) @@ -293,7 +293,7 @@ async def get_participant_ids_and_sequencing_group_ids_for_sequencing_type( rows = list( await self.connection.fetch_all( - _query, {'seqtype': sequencing_type, 'project': self.project} + _query, {'seqtype': sequencing_type, 'project': self.project_id} ) ) From 8a779eae9a7e49a02d2b19a7360fc596813fb068 Mon Sep 17 00:00:00 2001 From: Dan Coates Date: Mon, 17 Jun 2024 15:43:27 +1000 Subject: [PATCH 11/29] add route to update project members --- api/routes/project.py | 17 +++++++++-- db/python/tables/project.py | 59 +++++++++++++++++++++++++++++++------ models/models/project.py | 2 +- 3 files changed, 65 insertions(+), 13 deletions(-) diff --git a/api/routes/project.py b/api/routes/project.py index 799525fd6..b4c315ef2 
100644 --- a/api/routes/project.py +++ b/api/routes/project.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException from api.utils.db import ( Connection, @@ -6,7 +6,12 @@ get_projectless_db_connection, ) from db.python.tables.project import ProjectMemberWithRole, ProjectPermissionsTable -from models.models.project import FullWriteAccessRoles, Project, ProjectMemberRole +from models.models.project import ( + FullWriteAccessRoles, + Project, + ProjectMemberRole, + project_member_role_names, +) router = APIRouter(prefix='/project', tags=['project']) @@ -105,7 +110,6 @@ async def delete_project_data( success = await ptable.delete_project_data( project_id=connection.project.id, delete_project=delete_project, - author=connection.author, ) return {'success': success} @@ -124,6 +128,13 @@ async def update_project_members( await ptable.check_member_admin_permissions(author=connection.author) assert connection.project + + for member in members: + if member['role'] not in project_member_role_names: + raise HTTPException( + 400, f'Role {member["role"]} is not valid for member {member["member"]}' + ) + await ptable.set_project_members(project=connection.project, members=members) return {'success': True} diff --git a/db/python/tables/project.py b/db/python/tables/project.py index 7387a1c63..98800038b 100644 --- a/db/python/tables/project.py +++ b/db/python/tables/project.py @@ -1,12 +1,12 @@ # pylint: disable=global-statement -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Tuple from databases import Database from typing_extensions import TypedDict from api.settings import is_all_access from db.python.utils import Forbidden, NotFoundError, get_logger, to_db_json -from models.models.project import Project, ProjectMemberRole +from models.models.project import Project, project_member_role_names # Avoid circular import for type definition if TYPE_CHECKING: @@ -188,7 +188,6 @@ async def update_project(self, project_name: str, update: dict, author: str): async def delete_project_data(self, project_id: int, delete_project: bool) -> bool: """ Delete data in metamist project, requires project_creator_permissions - Can optionally delete the project also. 
""" if delete_project: # stop allowing delete project with analysis-runner entries @@ -239,16 +238,58 @@ async def delete_project_data(self, project_id: int, delete_project: bool) -> bo return True async def set_project_members( - self, project: Project, members: list[ProjectMemberRole] + self, project: Project, members: list[ProjectMemberWithRole] ): """ Set group members for a group (by name) """ - print('@TODO') - # group_id = await self.gtable.get_group_name_from_id(group_name) - # await self.gtable.set_group_members( - # group_id, members, audit_log_id=await self.audit_log_id() - # ) + + async with self.connection.transaction(): + + # Get existing rows so that we can keep the existing audit log ids + existing_rows = await self.connection.fetch_all( + """ + select project_id, member, role, audit_log_id + from project_member + where project_id = :project_id + """, + {'project_id': project.id}, + ) + + audit_log_id_map: dict[Tuple[str, str], int | None] = { + (r['member'], r['role']): r['audit_log_id'] for r in existing_rows + } + + # delete existing rows for project + await self.connection.execute( + """ + DELETE FROM project_member + WHERE project_id = :project_id + """, + {'project_id': project.id}, + ) + + new_audit_log_id = await self.audit_log_id() + + await self.connection.execute_many( + """ + INSERT INTO project_member + (project_id, member, role, audit_log_id) + VALUES (:project_id, :member, :role, :audit_log_id); + """, + [ + { + 'project_id': project.id, + 'member': m['member'], + 'role': m['role'], + 'audit_log_id': audit_log_id_map.get( + (m['member'], m['role']), new_audit_log_id + ), + } + for m in members + if m['role'] in project_member_role_names + ], + ) # endregion CREATE / UPDATE diff --git a/models/models/project.py b/models/models/project.py index 53640dd41..a2729a046 100644 --- a/models/models/project.py +++ b/models/models/project.py @@ -45,7 +45,7 @@ def is_test(self): """ return self.name == f'{self.dataset}-test' - @field_serializer("roles") + @field_serializer('roles') def serialize_roles(self, roles: set[ProjectMemberRole], _info): return [r.name for r in roles] From 18f5f3239022c55f13820f1a7621a885b147a194 Mon Sep 17 00:00:00 2001 From: Dan Coates Date: Tue, 18 Jun 2024 16:03:01 +1000 Subject: [PATCH 12/29] update project table methods to incorporate admin group roles Rather than having these roles separately - incorportate them into project level roles where it makes sense. That way the same permission constructs can be used for checking admin roles rather than having to have separate ones. 
--- api/routes/project.py | 39 +++-- db/python/connect.py | 23 +++ db/python/tables/project.py | 274 +++++++++++------------------------- models/models/project.py | 27 +++- 4 files changed, 153 insertions(+), 210 deletions(-) diff --git a/api/routes/project.py b/api/routes/project.py index b4c315ef2..0a809ecc0 100644 --- a/api/routes/project.py +++ b/api/routes/project.py @@ -5,12 +5,14 @@ get_project_db_connection, get_projectless_db_connection, ) -from db.python.tables.project import ProjectMemberWithRole, ProjectPermissionsTable +from db.python.tables.project import ProjectPermissionsTable from models.models.project import ( FullWriteAccessRoles, Project, ProjectMemberRole, - project_member_role_names, + ProjectMemberUpdate, + ReadAccessRoles, + updatable_project_member_role_names, ) router = APIRouter(prefix='/project', tags=['project']) @@ -25,7 +27,12 @@ async def get_all_projects(connection: Connection = get_projectless_db_connectio @router.get('/', operation_id='getMyProjects', response_model=list[str]) async def get_my_projects(connection: Connection = get_projectless_db_connection): """Get projects I have access to""" - return [p.name for p in connection.all_projects()] + return [ + p.name + for p in connection.projects_with_role( + ReadAccessRoles.union(FullWriteAccessRoles) + ) + ] @router.put('/', operation_id='createProject') @@ -69,12 +76,16 @@ async def get_seqr_projects(connection: Connection = get_projectless_db_connecti @router.post('/{project}/update', operation_id='updateProject') async def update_project( project_update_model: dict, - connection: Connection = get_project_db_connection(FullWriteAccessRoles), + connection: Connection = get_project_db_connection( + {ProjectMemberRole.project_admin} + ), ): """Update a project by project name""" ptable = ProjectPermissionsTable(connection) + # Updating a project additionally requires the project creator permission await ptable.check_project_creator_permissions(author=connection.author) + project = connection.project assert project return await ptable.update_project( @@ -86,7 +97,7 @@ async def update_project( async def delete_project_data( delete_project: bool = False, connection: Connection = get_project_db_connection( - {ProjectMemberRole.writer, ProjectMemberRole.data_manager} + {ProjectMemberRole.project_admin} ), ): """ @@ -117,8 +128,10 @@ async def delete_project_data( @router.patch('/{project}/members', operation_id='updateProjectMembers') async def update_project_members( - members: list[ProjectMemberWithRole], - connection: Connection = get_project_db_connection(FullWriteAccessRoles), + members: list[ProjectMemberUpdate], + connection: Connection = get_project_db_connection( + {ProjectMemberRole.project_member_admin} + ), ): """ Update project members for specific read / write group. 
@@ -127,14 +140,14 @@ async def update_project_members( ptable = ProjectPermissionsTable(connection) await ptable.check_member_admin_permissions(author=connection.author) - assert connection.project for member in members: - if member['role'] not in project_member_role_names: - raise HTTPException( - 400, f'Role {member["role"]} is not valid for member {member["member"]}' - ) - + for role in member.roles: + if role not in updatable_project_member_role_names: + raise HTTPException( + 400, f'Role {role} is not valid for member {member.member}' + ) + assert connection.project await ptable.set_project_members(project=connection.project, members=members) return {'success': True} diff --git a/db/python/connect.py b/db/python/connect.py index 65d14caf3..2d198c30a 100644 --- a/db/python/connect.py +++ b/db/python/connect.py @@ -85,6 +85,10 @@ def all_projects(self): """Return all projects that the current user has access to""" return list(self.project_id_map.values()) + def projects_with_role(self, allowed_roles: set[ProjectMemberRole]): + """Return all projects that the current user has access to""" + return [p for p in self.project_id_map.values() if p.roles & allowed_roles] + def get_and_check_access_to_projects( self, projects: Iterable[Project], allowed_roles: set[ProjectMemberRole] ): @@ -221,6 +225,25 @@ async def audit_log_id(self): return self._audit_log_id + async def refresh_projects(self): + """ + Re-fetch the projects for the current user and update the connection. + This only really needs to be run after project member updates or project + creation, and really only for tests. The API fetches projects on each request + so subsequent requests after updates will already have up-to-date data + """ + conn = self.connection + pt = ProjectPermissionsTable(connection=None, database_connection=conn) + + project_id_map, project_name_map = await pt.get_projects_accessible_by_user( + user=self.author + ) + self.project_id_map = project_id_map + self.project_name_map = project_name_map + + if self.project_id: + self.project = self.project_id_map.get(self.project_id) + class DatabaseConfiguration(abc.ABC): """Base class for DatabaseConfiguration""" diff --git a/db/python/tables/project.py b/db/python/tables/project.py index 98800038b..4f0f2c3da 100644 --- a/db/python/tables/project.py +++ b/db/python/tables/project.py @@ -2,11 +2,13 @@ from typing import TYPE_CHECKING, Any, Tuple from databases import Database -from typing_extensions import TypedDict -from api.settings import is_all_access -from db.python.utils import Forbidden, NotFoundError, get_logger, to_db_json -from models.models.project import Project, project_member_role_names +from db.python.utils import Forbidden, get_logger, to_db_json +from models.models.project import ( + Project, + ProjectMemberUpdate, + project_member_role_names, +) # Avoid circular import for type definition if TYPE_CHECKING: @@ -20,13 +22,6 @@ GROUP_NAME_MEMBERS_ADMIN = 'members-admin' -class ProjectMemberWithRole(TypedDict): - """Dict passed to the update project member endpoint to specify roles for members""" - - member: str - role: str - - class ProjectPermissionsTable: """ Capture project operations and queries @@ -37,7 +32,6 @@ class ProjectPermissionsTable: def __init__( self, connection: Connection | None, - allow_full_access: bool | None = None, database_connection: Database | None = None, ): self._connection = connection @@ -51,8 +45,6 @@ def __init__( else: self.connection = database_connection - self.gtable = GroupTable(self.connection, 
allow_full_access=allow_full_access) - async def audit_log_id(self): """ Generate (or return) a audit_log_id by inserting a row into the database @@ -65,33 +57,54 @@ async def audit_log_id(self): # region AUTH async def get_projects_accessible_by_user( - self, user: str, return_all_projects: bool = False + self, user: str ) -> tuple[dict[int, Project], dict[str, Project]]: """ Get projects that are accessible by the specified user """ parameters: dict[str, str] = { 'user': user, + 'project_creators_group_name': GROUP_NAME_PROJECT_CREATORS, + 'members_admin_group_name': GROUP_NAME_MEMBERS_ADMIN, } - # In most cases we want to exclude projects that the user doesn't explicitly - # have access to. If the user is in the project creators group it may be - # necessary to return all projects whether the user has explict access to them - # or not. - where_cond = 'WHERE pm.member = :user' if return_all_projects is False else '' - - _query = f""" + _query = """ + -- Check what admin groups the user belongs to, if they belong + -- to project-creators then a project_admin role will be added to + -- all projects, if they belong to members-admin then a `project_member_admin` + -- role will be appended to all projects. + WITH admin_roles AS ( + SELECT + CASE (g.name) + WHEN :project_creators_group_name THEN 'project_admin' + WHEN :members_admin_group_name THEN 'project_member_admin' + END + as role + FROM `group` g + JOIN group_member gm + ON gm.group_id = g.id + WHERE gm.member = :user + AND g.name in (:project_creators_group_name, :members_admin_group_name) + ), + -- Combine together the project roles and the admin roles + project_roles AS ( + SELECT pm.project_id, pm.member, pm.role + FROM project_member pm + WHERE pm.member = :user + UNION ALL + SELECT p.id as project_id, :user as member, ar.role + FROM project p + JOIN admin_roles ar ON TRUE + ) SELECT p.id, p.name, p.meta, p.dataset, - GROUP_CONCAT(pm.role) as roles + GROUP_CONCAT(pr.role) as roles FROM project p - LEFT JOIN project_member pm - ON p.id = pm.project_id - AND pm.member = :user - {where_cond} + JOIN project_roles pr + ON p.id = pr.project_id GROUP BY p.id """ @@ -107,22 +120,40 @@ async def get_projects_accessible_by_user( return project_id_map, project_name_map + async def check_if_member_in_group_by_name(self, group_name: str, member: str): + """Check if a user exists in the group""" + + _query = """ + SELECT COUNT(*) > 0 + FROM group_member gm + INNER JOIN `group` g ON g.id = gm.group_id + WHERE g.name = :group_name + AND gm.member = :member + """ + value = await self.connection.fetch_val( + _query, {'group_name': group_name, 'member': member} + ) + if value not in (0, 1): + raise ValueError( + f'Unexpected value {value!r} when determining access to {group_name} ' + f'for {member}' + ) + + return bool(value) + async def check_project_creator_permissions(self, author: str): """Check author has project_creator permissions""" - # check permissions in here - is_in_group = await self.gtable.check_if_member_in_group_name( + is_in_group = await self.check_if_member_in_group_by_name( group_name=GROUP_NAME_PROJECT_CREATORS, member=author ) - if not is_in_group: - raise Forbidden(f'{author} does not have access to creating project') + raise Forbidden(f'{author} does not have access to create a project') return True async def check_member_admin_permissions(self, author: str): """Check author has member_admin permissions""" - # check permissions in here - is_in_group = await self.gtable.check_if_member_in_group_name( + is_in_group = await 
self.check_if_member_in_group_by_name( GROUP_NAME_MEMBERS_ADMIN, author ) if not is_in_group: @@ -141,16 +172,14 @@ async def create_project( project_name: str, dataset_name: str, author: str, - check_permissions: bool = True, ): """Create project row""" - if check_permissions: - await self.check_project_creator_permissions(author) + await self.check_project_creator_permissions(author) async with self.connection.transaction(): _query = """\ - INSERT INTO project (name, dataset, audit_log_id, read_group_id, write_group_id) - VALUES (:name, :dataset, :audit_log_id, :read_group_id, :write_group_id) + INSERT INTO project (name, dataset, audit_log_id) + VALUES (:name, :dataset, :audit_log_id) RETURNING ID""" values = { 'name': project_name, @@ -160,6 +189,9 @@ async def create_project( project_id = await self.connection.fetch_val(_query, values) + if self._connection: + await self._connection.refresh_projects() + return project_id async def update_project(self, project_name: str, update: dict, author: str): @@ -238,7 +270,7 @@ async def delete_project_data(self, project_id: int, delete_project: bool) -> bo return True async def set_project_members( - self, project: Project, members: list[ProjectMemberWithRole] + self, project: Project, members: list[ProjectMemberUpdate] ): """ Set group members for a group (by name) @@ -271,6 +303,11 @@ async def set_project_members( new_audit_log_id = await self.audit_log_id() + db_members: list[dict[str, str]] = [] + + for m in members: + db_members.extend([{'member': m.member, 'role': r} for r in m.roles]) + await self.connection.execute_many( """ INSERT INTO project_member @@ -286,163 +323,12 @@ async def set_project_members( (m['member'], m['role']), new_audit_log_id ), } - for m in members + for m in db_members if m['role'] in project_member_role_names ], ) - # endregion CREATE / UPDATE - - -class GroupTable: - """ - Capture Group table operations and queries - """ - - table_name = 'group' - - def __init__(self, connection: Database, allow_full_access: bool | None = None): - if not isinstance(connection, Database): - raise ValueError( - f'Invalid type connection, expected Database, got {type(connection)}, ' - 'did you forget to call connection.connection?' 
- ) - self.connection: Database = connection - self.allow_full_access = ( - allow_full_access if allow_full_access is not None else is_all_access() - ) - - async def get_group_members(self, group_id: int) -> set[str]: - """Get project IDs for sampleIds (mostly for checking auth)""" - _query = """ - SELECT member - FROM `group` - WHERE id = :group_id - """ - rows = await self.connection.fetch_all(_query, {'group_id': group_id}) - members = set(r['member'] for r in rows) - return members - - async def get_group_name_from_id(self, name: str) -> int: - """Get group name to group id""" - _query = """ - SELECT id, name - FROM `group` - WHERE name = :name - """ - row = await self.connection.fetch_one(_query, {'name': name}) - if not row: - raise NotFoundError(f'Could not find group {name}') - return row['id'] - - async def get_group_name_to_ids(self, names: list[str]) -> dict[str, int]: - """Get group name to group id""" - _query = """ - SELECT id, name - FROM `group` - WHERE name IN :names - """ - rows = await self.connection.fetch_all(_query, {'names': names}) - return {r['name']: r['id'] for r in rows} - - async def check_if_member_in_group(self, group_id: int, member: str) -> bool: - """Check if a member is in a group""" - if self.allow_full_access: - return True - - _query = """ - SELECT COUNT(*) > 0 - FROM group_member gm - WHERE gm.group_id = :group_id - AND gm.member = :member - """ - value = await self.connection.fetch_val( - _query, {'group_id': group_id, 'member': member} - ) - if value not in (0, 1): - raise ValueError( - f'Unexpected value {value!r} when determining access to group with ID ' - f'{group_id} for {member}' - ) - return bool(value) - - async def check_if_member_in_group_name(self, group_name: str, member: str) -> bool: - """Check if a member is in a group""" - if self.allow_full_access: - return True - - _query = """ - SELECT COUNT(*) > 0 - FROM group_member gm - INNER JOIN `group` g ON g.id = gm.group_id - WHERE g.name = :group_name - AND gm.member = :member - """ - value = await self.connection.fetch_val( - _query, {'group_name': group_name, 'member': member} - ) - if value not in (0, 1): - raise ValueError( - f'Unexpected value {value!r} when determining access to {group_name} ' - f'for {member}' - ) - - return bool(value) - - async def check_which_groups_member_has( - self, group_ids: set[int], member: str - ) -> set[int]: - """ - Check which groups a member has - """ - if self.allow_full_access: - return group_ids - - _query = """ - SELECT gm.group_id as gid - FROM group_member gm - WHERE gm.member = :member AND gm.group_id IN :group_ids - """ - results = await self.connection.fetch_all( - _query, {'group_ids': group_ids, 'member': member} - ) - return set(r['gid'] for r in results) - - async def create_group(self, name: str, audit_log_id: int) -> int: - """Create a new group""" - _query = """ - INSERT INTO `group` (name, audit_log_id) - VALUES (:name, :audit_log_id) - RETURNING id - """ - return await self.connection.fetch_val( - _query, {'name': name, 'audit_log_id': audit_log_id} - ) + if self._connection: + await self._connection.refresh_projects() - async def set_group_members( - self, group_id: int, members: list[str], audit_log_id: int - ): - """ - Set group members for a group (by id) - """ - await self.connection.execute( - """ - DELETE FROM group_member - WHERE group_id = :group_id - """, - {'group_id': group_id}, - ) - await self.connection.execute_many( - """ - INSERT INTO group_member (group_id, member, audit_log_id) - VALUES (:group_id, :member, 
:audit_log_id) - """, - [ - { - 'group_id': group_id, - 'member': member, - 'audit_log_id': audit_log_id, - } - for member in members - ], - ) + # endregion CREATE / UPDATE diff --git a/models/models/project.py b/models/models/project.py index a2729a046..6f9cc184f 100644 --- a/models/models/project.py +++ b/models/models/project.py @@ -9,11 +9,19 @@ ProjectId = int ProjectMemberRole = Enum( - 'ProjectMemberRole', ['reader', 'contributor', 'writer', 'data_manager'] + 'ProjectMemberRole', + [ + 'reader', + 'contributor', + 'writer', + 'data_manager', + 'project_admin', + 'project_member_admin', + ], ) -project_member_role_names = [r.name for r in ProjectMemberRole] +AdminRoles = {ProjectMemberRole.project_admin, ProjectMemberRole.project_member_admin} # These roles have read access to a project ReadAccessRoles = { ProjectMemberRole.reader, @@ -25,6 +33,11 @@ # Only write has full write access FullWriteAccessRoles = {ProjectMemberRole.writer} +project_member_role_names = [r.name for r in ProjectMemberRole] +updatable_project_member_role_names = [ + r.name for r in ProjectMemberRole if r not in AdminRoles +] + class Project(SMBase): """Row for project in 'project' table""" @@ -46,7 +59,8 @@ def is_test(self): return self.name == f'{self.dataset}-test' @field_serializer('roles') - def serialize_roles(self, roles: set[ProjectMemberRole], _info): + def serialize_roles(self, roles: set[ProjectMemberRole]): + """convert roles into a form that can be returned from the API""" return [r.name for r in roles] @staticmethod @@ -61,3 +75,10 @@ def from_db(kwargs): ProjectMemberRole[r] for r in role_list if r in project_member_role_names } return Project(**kwargs) + + +class ProjectMemberUpdate(SMBase): + """Item included in list of project member updates""" + + member: str + roles: list[str] From 606302827710d35a86a8b9c418132ab09d3f8724 Mon Sep 17 00:00:00 2001 From: Dan Coates Date: Tue, 18 Jun 2024 16:05:00 +1000 Subject: [PATCH 13/29] remove allow all access setting, it is better to use real controls even when running locally, it is better to lean on the actual access controls rather than allowing all access. 
This way we can catch issues with permission checks during development --- api/server.py | 5 ++--- api/settings.py | 16 ++-------------- api/utils/db.py | 11 ++++++----- 3 files changed, 10 insertions(+), 22 deletions(-) diff --git a/api/server.py b/api/server.py index 9a57ff1cb..c0b2e5188 100644 --- a/api/server.py +++ b/api/server.py @@ -12,11 +12,10 @@ from api import routes from api.graphql.schema import MetamistGraphQLRouter # type: ignore -from api.settings import PROFILE_REQUESTS, SKIP_DATABASE_CONNECTION +from api.settings import PROFILE_REQUESTS, SKIP_DATABASE_CONNECTION, SM_ENVIRONMENT from api.utils.exceptions import determine_code_from_error from api.utils.openapi import get_openapi_schema_func from db.python.connect import SMConnections -from db.python.tables.project import is_all_access from db.python.utils import get_logger # This tag is automatically updated by bump2version @@ -55,7 +54,7 @@ async def app_lifespan(_: FastAPI): app.add_middleware(PyInstrumentProfilerMiddleware) # type: ignore -if is_all_access(): +if SM_ENVIRONMENT == 'local': app.add_middleware( CORSMiddleware, allow_origins=['*'], diff --git a/api/settings.py b/api/settings.py index adae52b89..91d58c1e8 100644 --- a/api/settings.py +++ b/api/settings.py @@ -9,9 +9,8 @@ LOG_DATABASE_QUERIES = ( os.getenv('SM_LOG_DATABASE_QUERIES', 'false').lower() in TRUTH_SET ) -_ALLOW_ALL_ACCESS: bool = os.getenv('SM_ALLOWALLACCESS', 'n').lower() in TRUTH_SET _DEFAULT_USER = os.getenv('SM_LOCALONLY_DEFAULTUSER') -SM_ENVIRONMENT = os.getenv('SM_ENVIRONMENT', 'local').lower() +SM_ENVIRONMENT = os.getenv('SM_ENVIRONMENT', 'production').lower() SKIP_DATABASE_CONNECTION = bool(os.getenv('SM_SKIP_DATABASE_CONNECTION')) PROFILE_REQUESTS = os.getenv('SM_PROFILE_REQUESTS', 'false').lower() in TRUTH_SET IGNORE_GCP_CREDENTIALS_ERROR = os.getenv('SM_IGNORE_GCP_CREDENTIALS_ERROR') in TRUTH_SET @@ -60,22 +59,11 @@ def get_default_user() -> str | None: """Determine if a default user is available""" - if is_all_access() and _DEFAULT_USER: + if SM_ENVIRONMENT == 'local' and _DEFAULT_USER: return _DEFAULT_USER return None -def is_all_access() -> bool: - """Does SM have full access""" - return _ALLOW_ALL_ACCESS - - -def set_all_access(access: bool): - """Set full_access for future use""" - global _ALLOW_ALL_ACCESS - _ALLOW_ALL_ACCESS = access - - @lru_cache def get_slack_token(allow_empty=False): """Get slack token""" diff --git a/api/utils/db.py b/api/utils/db.py index aa0cb38a0..951f63c20 100644 --- a/api/utils/db.py +++ b/api/utils/db.py @@ -60,6 +60,12 @@ def authenticate( If a token (OR Google IAP auth jwt) is provided, return the email, else raise an Exception """ + + if default_user := get_default_user(): + # this should only happen in LOCAL environments + logging.info(f'Using {default_user} as authenticated user') + return default_user + if x_goog_iap_jwt_assertion: # We have to PREFER the IAP's identity, otherwise you could have a case where # the JWT is forged, but IAP lets it through and authenticates, but then we take @@ -71,11 +77,6 @@ def authenticate( if token: return email_from_id_token(token.credentials) - if default_user := get_default_user(): - # this should only happen in LOCAL environments - logging.info(f'Using {default_user} as authenticated user') - return default_user - raise HTTPException(status_code=401, detail='Not authenticated :(') From 32144c88911ebbdce8b9d92ae48b303dfc899d44 Mon Sep 17 00:00:00 2001 From: Dan Coates Date: Tue, 18 Jun 2024 16:06:00 +1000 Subject: [PATCH 14/29] change graphql project list to 
only list projects with certain roles to avoid listing absolutely every project for users with admin roles --- api/graphql/schema.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/api/graphql/schema.py b/api/graphql/schema.py index 42edd8e72..bbc800122 100644 --- a/api/graphql/schema.py +++ b/api/graphql/schema.py @@ -52,7 +52,7 @@ ) from models.models.analysis_runner import AnalysisRunnerInternal from models.models.family import PedRowInternal -from models.models.project import ProjectId, ReadAccessRoles +from models.models.project import FullWriteAccessRoles, ProjectId, ReadAccessRoles from models.models.sample import sample_id_transform_to_raw from models.utils.cohort_id_format import cohort_id_format, cohort_id_transform_to_raw from models.utils.cohort_template_id_format import ( @@ -1213,7 +1213,9 @@ async def my_projects( self, info: Info[GraphQLContext, 'Query'] ) -> list[GraphQLProject]: connection = info.context['connection'] - projects = connection.all_projects() + projects = connection.projects_with_role( + ReadAccessRoles.union(FullWriteAccessRoles) + ) return [GraphQLProject.from_internal(p) for p in projects] @strawberry.field From 4641f325cfcb756a9689d4a82481ff785d53227a Mon Sep 17 00:00:00 2001 From: Dan Coates Date: Tue, 18 Jun 2024 16:06:29 +1000 Subject: [PATCH 15/29] fix data generation scripts to work with new permissions --- test/data/generate_data.py | 9 +++ test/data/generate_seqr_project_data.py | 103 +++++++++++++++++------- 2 files changed, 83 insertions(+), 29 deletions(-) diff --git a/test/data/generate_data.py b/test/data/generate_data.py index 9cd04b690..678f49002 100755 --- a/test/data/generate_data.py +++ b/test/data/generate_data.py @@ -3,6 +3,7 @@ import argparse import asyncio import datetime +import os import random from pathlib import Path from pprint import pprint @@ -74,6 +75,14 @@ async def main(ped_path=default_ped_location, project='greek-myth'): await papi.create_project_async( name=project, dataset=project, create_test_project=False ) + default_user = os.getenv('SM_LOCALONLY_DEFAULTUSER') + + await papi.update_project_members_async( + project=project, + project_member_update=[ + {'member': default_user, 'roles': ['reader', 'writer']} + ], + ) with open(ped_path, encoding='utf-8') as f: # skip the first line diff --git a/test/data/generate_seqr_project_data.py b/test/data/generate_seqr_project_data.py index 5a3f7f5bf..c2f7cb858 100644 --- a/test/data/generate_seqr_project_data.py +++ b/test/data/generate_seqr_project_data.py @@ -4,6 +4,7 @@ import csv import datetime import logging +import os import random import sys import tempfile @@ -13,9 +14,9 @@ from metamist.apis import AnalysisApi, FamilyApi, ParticipantApi, ProjectApi, SampleApi from metamist.graphql import gql, query_async from metamist.model.analysis import Analysis -from metamist.model.analysis_status import AnalysisStatus from metamist.models import AssayUpsert, SampleUpsert, SequencingGroupUpsert from metamist.parser.generic_parser import chunk +from models.enums import AnalysisStatus NAMES = [ 'SOLAR', @@ -91,6 +92,7 @@ class ped_row: """The pedigree row class""" + def __init__(self, values): ( self.family_id, @@ -156,24 +158,28 @@ def generate_pedigree_rows(num_families=1): individual_id = generate_random_id(used_ids) if parent_sex == 1: rows.append( - ped_row([ - family_id, - individual_id, - parent_id, - '', - random.choice([0, 1]), - 2,] + ped_row( + [ + family_id, + individual_id, + parent_id, + '', + random.choice([0, 1]), + 2, + ] ) ) else: 
rows.append( - ped_row([ - family_id, - individual_id, - '', - parent_id, - random.choice([0, 1]), - 2,] + ped_row( + [ + family_id, + individual_id, + '', + parent_id, + random.choice([0, 1]), + 2, + ] ) ) @@ -201,13 +207,24 @@ def generate_pedigree_rows(num_families=1): 0 ] # Randomly assign affected status rows.append( - ped_row([family_id, individual_id, paternal_id, maternal_id, sex, affected]) + ped_row( + [ + family_id, + individual_id, + paternal_id, + maternal_id, + sex, + affected, + ] + ) ) return rows -def generate_sequencing_type(count_distribution: dict[int, float], sequencing_types: list[str]): +def generate_sequencing_type( + count_distribution: dict[int, float], sequencing_types: list[str] +): """Return a random length of random sequencing types""" k = random.choices( list(count_distribution.keys()), @@ -272,7 +289,9 @@ async def generate_sample_entries( samples = [] for participant_eid, participant_id in participant_id_map.items(): - nsamples = generate_random_number_within_distribution(default_count_probabilities) + nsamples = generate_random_number_within_distribution( + default_count_probabilities + ) for i in range(nsamples): sample = SampleUpsert( external_id=f'{participant_eid}_{i+1}', @@ -289,7 +308,9 @@ async def generate_sample_entries( ) samples.append(sample) - for stype in generate_sequencing_type(default_count_probabilities, sequencing_types): + for stype in generate_sequencing_type( + default_count_probabilities, sequencing_types + ): facility = random.choice( [ 'Amazing sequence centre', @@ -309,13 +330,17 @@ async def generate_sample_entries( assays=[], ) sample.sequencing_groups.append(sg) - for _ in range(generate_random_number_within_distribution(default_count_probabilities)): + for _ in range( + generate_random_number_within_distribution( + default_count_probabilities + ) + ): sg.assays.append( AssayUpsert( type='sequencing', meta={ 'facility': facility, - 'reads' : [], + 'reads': [], 'coverage': f'{random.choice([30, 90, 300, 9000, "?"])}x', 'sequencing_type': stype, 'sequencing_technology': stechnology, @@ -339,7 +364,7 @@ async def generate_cram_analyses(project: str, analyses_to_insert: list[Analysis # Randomly allocate some of the sequencing groups to be aligned aligned_sgs = random.sample( sequencing_groups, - k=random.randint(int(len(sequencing_groups)/2), len(sequencing_groups)) + k=random.randint(int(len(sequencing_groups) / 2), len(sequencing_groups)), ) # Insert completed CRAM analyses for the aligned sequencing groups @@ -351,11 +376,14 @@ async def generate_cram_analyses(project: str, analyses_to_insert: list[Analysis status=AnalysisStatus('completed'), output=f'FAKE://{project}/crams/{sg["id"]}.cram', timestamp_completed=( - datetime.datetime.now() - datetime.timedelta(days=random.randint(1, 15)) + datetime.datetime.now() + - datetime.timedelta(days=random.randint(1, 15)) ).isoformat(), meta={ # random size between 5, 25 GB - 'size': random.randint(5 * 1024, 25 * 1024) * 1024 * 1024, + 'size': random.randint(5 * 1024, 25 * 1024) + * 1024 + * 1024, }, ) for sg in aligned_sgs @@ -365,7 +393,9 @@ async def generate_cram_analyses(project: str, analyses_to_insert: list[Analysis return aligned_sgs -async def generate_joint_called_analyses(project: str, aligned_sgs: list[dict], analyses_to_insert: list[Analysis]): +async def generate_joint_called_analyses( + project: str, aligned_sgs: list[dict], analyses_to_insert: list[Analysis] +): """ Selects a subset of the aligned sequencing groups for the input project and generates joint-called 
AnnotateDataset and ES-index analysis entries for them. @@ -373,7 +403,9 @@ async def generate_joint_called_analyses(project: str, aligned_sgs: list[dict], seq_type_to_sg_list = { 'genome': [sg['id'] for sg in aligned_sgs if sg['type'] == 'genome'], 'exome': [sg['id'] for sg in aligned_sgs if sg['type'] == 'exome'], - 'transcriptome': [sg['id'] for sg in aligned_sgs if sg['type'] == 'transcriptome'] + 'transcriptome': [ + sg['id'] for sg in aligned_sgs if sg['type'] == 'transcriptome' + ], } for seq_type, sg_list in seq_type_to_sg_list.items(): if not sg_list: @@ -395,7 +427,7 @@ async def generate_joint_called_analyses(project: str, aligned_sgs: list[dict], status=AnalysisStatus('completed'), output=f'FAKE::{project}-{seq_type}-es-{datetime.date.today()}', meta={'stage': 'MtToEs', 'sequencing_type': seq_type}, - ) + ), ] ) @@ -420,9 +452,20 @@ async def main(): await papi.create_project_async( name=project, dataset=project, create_test_project=False ) + + default_user = os.getenv('SM_LOCALONLY_DEFAULTUSER') + + await papi.update_project_members_async( + project=project, + project_member_update=[ + {'member': default_user, 'roles': ['reader', 'writer']} + ], + ) + logging.info(f'Created project "{project}"') await papi.update_project_async( - project=project, body={'meta': {'is_seqr': 'true'}}, + project=project, + body={'meta': {'is_seqr': 'true'}}, ) logging.info(f'Set {project} as seqr project') @@ -436,7 +479,9 @@ async def main(): for analyses in chunk(analyses_to_insert, 50): logging.info(f'Inserting {len(analyses)} analysis entries') - await asyncio.gather(*[aapi.create_analysis_async(project, a) for a in analyses]) + await asyncio.gather( + *[aapi.create_analysis_async(project, a) for a in analyses] + ) if __name__ == '__main__': From eeca5e65e99ea8b1fee4240b0e82fcd6df4e3b3c Mon Sep 17 00:00:00 2001 From: Dan Coates Date: Tue, 18 Jun 2024 16:06:41 +1000 Subject: [PATCH 16/29] fix tests to work with new permissions structures --- test/test_assay.py | 13 +-- test/test_pedigree.py | 4 +- test/test_project_groups.py | 214 ++++++++++++++++-------------------- test/test_sample.py | 4 +- test/test_search.py | 2 +- test/testbase.py | 61 ++++++++-- 6 files changed, 155 insertions(+), 143 deletions(-) diff --git a/test/test_assay.py b/test/test_assay.py index 6c26d26ad..5e6fab93e 100644 --- a/test/test_assay.py +++ b/test/test_assay.py @@ -53,7 +53,7 @@ async def test_not_found_assay(self): @run_as_sync async def get(): - return await self.assaylayer.get_assay_by_id(-1, check_project_id=False) + return await self.assaylayer.get_assay_by_id(-1) self.assertRaises(NotFoundError, get) @@ -78,9 +78,7 @@ async def test_upsert_assay(self): ) ) - assay = await self.assaylayer.get_assay_by_id( - assay_id=upserted_assay.id, check_project_id=False - ) + assay = await self.assaylayer.get_assay_by_id(assay_id=upserted_assay.id) self.assertEqual(upserted_assay.id, assay.id) self.assertEqual(self.sample_id_raw, int(assay.sample_id)) @@ -377,9 +375,7 @@ async def test_update(self): ) ) - update_assay = await self.assaylayer.get_assay_by_id( - assay_id=assay.id, check_project_id=False - ) + update_assay = await self.assaylayer.get_assay_by_id(assay_id=assay.id) self.assertEqual(assay.id, update_assay.id) self.assertEqual(self.sample_id_raw, int(update_assay.sample_id)) @@ -408,8 +404,7 @@ async def test_update_type(self): # cycle through all statuses, and check that works await self.assaylayer.upsert_assay( - AssayUpsertInternal(id=assay.id, type='metabolomics'), - check_project_id=False, + 
AssayUpsertInternal(id=assay.id, type='metabolomics') ) row_to_check = await self.connection.connection.fetch_one( 'SELECT type FROM assay WHERE id = :id', diff --git a/test/test_pedigree.py b/test/test_pedigree.py index 8966c8cc3..6bccac966 100644 --- a/test/test_pedigree.py +++ b/test/test_pedigree.py @@ -24,7 +24,7 @@ async def test_import_get_pedigree(self): ) pedigree_dicts = await fl.get_pedigree( - project=self.connection.project, + project=self.connection.project_id, replace_with_participant_external_ids=True, replace_with_family_external_ids=True, ) @@ -60,7 +60,7 @@ async def test_pedigree_without_family(self): ) rows = await fl.get_pedigree( - project=self.connection.project, + project=self.connection.project_id, include_participants_not_in_families=True, replace_with_participant_external_ids=True, ) diff --git a/test/test_project_groups.py b/test/test_project_groups.py index e545f6986..76c2d00ca 100644 --- a/test/test_project_groups.py +++ b/test/test_project_groups.py @@ -7,8 +7,12 @@ Forbidden, ProjectPermissionsTable, ) -from db.python.utils import NotFoundError -from models.models.group import GroupProjectRole, ReadAccessRoles +from models.models.project import ( + FullWriteAccessRoles, + ProjectMemberRole, + ProjectMemberUpdate, + ReadAccessRoles, +) class TestGroupAccess(DbIsolatedTest): @@ -22,9 +26,9 @@ async def setUp(self): super().setUp() # specifically required to test permissions - self.pttable = ProjectPermissionsTable(self.connection, False) + self.pttable = ProjectPermissionsTable(self.connection) - async def _add_group_member_direct(self, group_name): + async def _add_group_member_direct(self, group_name: str): """ Helper method to directly add members to group with name """ @@ -44,87 +48,6 @@ async def _add_group_member_direct(self, group_name): }, ) - @run_as_sync - async def test_group_set_members_failed_no_permission(self): - """ - Test that a user without permission cannot set members - """ - with self.assertRaises(Forbidden): - await self.pttable.set_group_members( - 'another-test-project', ['user1'], self.author - ) - - @run_as_sync - async def test_group_set_members_failed_not_exists(self): - """ - Test that a user with permission, cannot set members - for a group that doesn't exist - """ - await self._add_group_member_direct(GROUP_NAME_MEMBERS_ADMIN) - with self.assertRaises(NotFoundError): - await self.pttable.set_group_members( - 'another-test-project', ['user1'], self.author - ) - - @run_as_sync - async def test_group_set_members_succeeded(self): - """ - Test that a user with permission, can set members for a group that exists - """ - await self._add_group_member_direct(GROUP_NAME_MEMBERS_ADMIN) - - g = str(uuid.uuid4()) - await self.pttable.gtable.create_group(g, await self.audit_log_id()) - - self.assertFalse( - await self.pttable.gtable.check_if_member_in_group_name(g, 'user1') - ) - self.assertFalse( - await self.pttable.gtable.check_if_member_in_group_name(g, 'user2') - ) - - await self.pttable.set_group_members( - group_name=g, members=['user1', 'user2'], author=self.author - ) - - self.assertTrue( - await self.pttable.gtable.check_if_member_in_group_name(g, 'user1') - ) - self.assertTrue( - await self.pttable.gtable.check_if_member_in_group_name(g, 'user2') - ) - - @run_as_sync - async def test_check_which_groups_member_is_missing(self): - """Test the check_which_groups_member_has function""" - await self._add_group_member_direct(GROUP_NAME_MEMBERS_ADMIN) - - group = str(uuid.uuid4()) - gid = await self.pttable.gtable.create_group(group, 
await self.audit_log_id()) - present_gids = await self.pttable.gtable.check_which_groups_member_has( - {gid}, self.author - ) - missing_gids = {gid} - present_gids - self.assertEqual(1, len(missing_gids)) - self.assertEqual(gid, missing_gids.pop()) - - @run_as_sync - async def test_check_which_groups_member_is_missing_none(self): - """Test the check_which_groups_member_has function""" - await self._add_group_member_direct(GROUP_NAME_MEMBERS_ADMIN) - - group = str(uuid.uuid4()) - gid = await self.pttable.gtable.create_group(group, await self.audit_log_id()) - await self.pttable.gtable.set_group_members( - gid, [self.author], audit_log_id=await self.audit_log_id() - ) - present_gids = await self.pttable.gtable.check_which_groups_member_has( - group_ids={gid}, member=self.author - ) - missing_gids = {gid} - present_gids - - self.assertEqual(0, len(missing_gids)) - @run_as_sync async def test_project_creators_failed(self): """ @@ -147,12 +70,13 @@ async def test_project_create_succeed(self): project_id = await self.pttable.create_project(g, g, self.author) # pylint: disable=protected-access - await self.pttable._get_project_by_id(project_id) + project_id_map, _ = await self.pttable.get_projects_accessible_by_user( + user=self.author + ) - # test that the group names make sense - # @TODO update tests to handle new role groups - # self.assertIsNotNone(project.read_group_id) - # self.assertIsNotNone(project.write_group_id) + project = project_id_map.get(project_id) + assert project + self.assertEqual(project.name, g) class TestProjectAccess(DbIsolatedTest): @@ -164,9 +88,12 @@ async def setUp(self): super().setUp() # specifically required to test permissions - self.pttable = ProjectPermissionsTable(self.connection, False) + self.pttable = ProjectPermissionsTable(self.connection) - async def _add_group_member_direct(self, group_name): + async def _add_group_member_direct( + self, + group_name: str, + ): """ Helper method to directly add members to group with name """ @@ -196,13 +123,13 @@ async def test_no_project_access(self): project_id = await self.pttable.create_project(g, g, self.author) with self.assertRaises(Forbidden): - await self.pttable.get_and_check_access_to_project_for_id( - user=self.author, project_id=project_id, allowed_roles=ReadAccessRoles + self.connection.check_access_to_projects_for_ids( + project_ids=[project_id], allowed_roles=ReadAccessRoles ) with self.assertRaises(Forbidden): - await self.pttable.get_and_check_access_to_project_for_name( - user=self.author, project_name=g, allowed_roles=ReadAccessRoles + self.connection.get_and_check_access_to_projects_for_names( + project_names=[g], allowed_roles=ReadAccessRoles ) @run_as_sync @@ -216,23 +143,64 @@ async def test_project_access_success(self): g = str(uuid.uuid4()) pid = await self.pttable.create_project(g, g, self.author) - await self.pttable.set_group_members( - group_name=self.pttable.get_project_group_name( - g, role=GroupProjectRole.read - ), - members=[self.author], - author=self.author, + + project_id_map, _ = await self.pttable.get_projects_accessible_by_user( + user=self.author + ) + project = project_id_map.get(pid) + assert project + await self.pttable.set_project_members( + project=project, + members=[ProjectMemberUpdate(member=self.author, roles=['reader'])], ) - project_for_id = await self.pttable.get_and_check_access_to_project_for_id( - user=self.author, project_id=pid, allowed_roles=ReadAccessRoles + project_for_id = self.connection.get_and_check_access_to_projects_for_ids( + project_ids=[pid], 
allowed_roles=ReadAccessRoles ) - self.assertEqual(pid, project_for_id.id) + user_project_for_id = next(p for p in project_for_id) + self.assertEqual(pid, user_project_for_id.id) - project_for_name = await self.pttable.get_and_check_access_to_project_for_name( - user=self.author, project_name=g, allowed_roles=ReadAccessRoles + project_for_name = self.connection.get_and_check_access_to_projects_for_names( + project_names=[g], allowed_roles=ReadAccessRoles ) - self.assertEqual(pid, project_for_name.id) + user_project_for_name = next(p for p in project_for_name) + self.assertEqual(g, user_project_for_name.name) + + @run_as_sync + async def test_project_access_insufficient(self): + """ + Test that a user with access to a project will be disallowed if their access is + not sufficient + """ + await self._add_group_member_direct(GROUP_NAME_PROJECT_CREATORS) + await self._add_group_member_direct(GROUP_NAME_MEMBERS_ADMIN) + + g = str(uuid.uuid4()) + + pid = await self.pttable.create_project(g, g, self.author) + + project_id_map, _ = await self.pttable.get_projects_accessible_by_user( + user=self.author + ) + project = project_id_map.get(pid) + assert project + # Give the user read access to the project + await self.pttable.set_project_members( + project=project, + members=[ProjectMemberUpdate(member=self.author, roles=['reader'])], + ) + + # But require Write access + + with self.assertRaises(Forbidden): + self.connection.check_access_to_projects_for_ids( + project_ids=[project.id], allowed_roles=FullWriteAccessRoles + ) + + with self.assertRaises(Forbidden): + self.connection.get_and_check_access_to_projects_for_names( + project_names=[g], allowed_roles=FullWriteAccessRoles + ) @run_as_sync async def test_get_my_projects(self): @@ -246,17 +214,27 @@ async def test_get_my_projects(self): pid = await self.pttable.create_project(g, g, self.author) - await self.pttable.set_group_members( - group_name=self.pttable.get_project_group_name( - g, role=GroupProjectRole.read - ), - members=[self.author], - author=self.author, + project_id_map, _ = await self.pttable.get_projects_accessible_by_user( + user=self.author + ) + project = project_id_map.get(pid) + assert project + # Give the user read access to the project + await self.pttable.set_project_members( + project=project, + members=[ProjectMemberUpdate(member=self.author, roles=['contributor'])], + ) + + project_id_map, project_name_map = ( + await self.pttable.get_projects_accessible_by_user(user=self.author) ) - projects = await self.pttable.get_projects_accessible_by_user( - user=self.author, allowed_roles=ReadAccessRoles + # Get projects with at least a read access role + my_projects = self.connection.projects_with_role( + {ProjectMemberRole.contributor} ) + print(my_projects) - self.assertEqual(1, len(projects)) - self.assertEqual(pid, projects[0].id) + self.assertEqual(len(project_id_map), len(project_name_map)) + self.assertEqual(len(my_projects), 1) + self.assertEqual(pid, my_projects[0].id) diff --git a/test/test_sample.py b/test/test_sample.py index e5b8639b7..454e434ac 100644 --- a/test/test_sample.py +++ b/test/test_sample.py @@ -45,7 +45,7 @@ async def test_get_sample(self): ) ) - sample = await self.slayer.get_by_id(s.id, check_project_id=False) + sample = await self.slayer.get_by_id(s.id) self.assertEqual('blood', sample.type) self.assertDictEqual(meta_dict, sample.meta) @@ -68,7 +68,7 @@ async def test_update_sample(self): SampleUpsertInternal(id=s.id, external_id=new_external_id) ) - sample = await self.slayer.get_by_id(s.id, 
check_project_id=False) + sample = await self.slayer.get_by_id(s.id) self.assertEqual(new_external_id, sample.external_id) diff --git a/test/test_search.py b/test/test_search.py index d5498b098..45e0e309c 100644 --- a/test/test_search.py +++ b/test/test_search.py @@ -223,7 +223,7 @@ async def test_search_mixed(self): ) all_results = await self.schlay.search( - query='X:', project_ids=[self.connection.project] + query='X:', project_ids=[self.connection.project_id] ) self.assertEqual(3, len(all_results)) diff --git a/test/testbase.py b/test/testbase.py index 9c8121821..d03520a23 100644 --- a/test/testbase.py +++ b/test/testbase.py @@ -17,7 +17,6 @@ from api.graphql.loaders import get_context # type: ignore from api.graphql.schema import schema # type: ignore -from api.settings import set_all_access from db.python.connect import ( TABLES_ORDERED_BY_FK_DEPS, Connection, @@ -25,8 +24,7 @@ SMConnections, ) from db.python.tables.project import ProjectPermissionsTable -from models.models.group import FullWriteAccessRoles -from models.models.project import ProjectId +from models.models.project import Project, ProjectId, ProjectMemberUpdate # use this to determine where the db directory is relatively, # as pycharm runs in "test/" folder, and GH runs them in git root @@ -82,6 +80,8 @@ class DbTest(unittest.TestCase): author: str project_id: ProjectId project_name: str + project_id_map: dict[ProjectId, Project] + project_name_map: dict[str, Project] @classmethod def setUpClass(cls) -> None: @@ -97,7 +97,6 @@ async def setup(): """ logger = logging.getLogger() try: - set_all_access(True) db = MySqlContainer('mariadb:11.2.2') port_to_expose = find_free_port() # override the default port to map the container to @@ -138,23 +137,56 @@ async def setup(): formed_connection = Connection( connection=sm_db, author=cls.author, - allowed_roles=FullWriteAccessRoles, + project_id_map={}, + project_name_map={}, on_behalf_of=None, ar_guid=None, project=None, ) + + # Add the test user to the admin groups + await formed_connection.connection.execute( + f""" + INSERT INTO group_member(group_id, member) + SELECT id, '{cls.author}' + FROM `group` WHERE name IN('project-creators', 'members-admin') + """ + ) + ppt = ProjectPermissionsTable( connection=formed_connection, - allow_full_access=True, ) cls.project_name = 'test' cls.project_id = await ppt.create_project( project_name=cls.project_name, dataset_name=cls.project_name, author=cls.author, - check_permissions=False, ) - formed_connection.project = cls.project_id + + _, initial_project_name_map = await ppt.get_projects_accessible_by_user( + user=cls.author + ) + created_project = initial_project_name_map.get(cls.project_name) + assert created_project + + await ppt.set_project_members( + project=created_project, + members=[ + ProjectMemberUpdate( + member=cls.author, roles=['reader', 'writer'] + ) + ], + ) + + # Get the new project map now that project membership is updated + project_id_map, project_name_map = ( + await ppt.get_projects_accessible_by_user(user=cls.author) + ) + + cls.project_id_map = project_id_map + cls.project_name_map = project_name_map + + # formed_connection.project = cls.project_id except subprocess.CalledProcessError as e: logging.exception(e) @@ -186,9 +218,10 @@ def setUp(self) -> None: # audit_log ID for each test self.connection = Connection( connection=self._connection, - project=self.project_id, + project=self.project_id_map.get(self.project_id), author=self.author, - allowed_roles=FullWriteAccessRoles, + project_id_map=self.project_id_map, 
+ project_name_map=self.project_name_map, ar_guid=None, on_behalf_of=None, ) @@ -249,7 +282,13 @@ class DbIsolatedTest(DbTest): async def setUp(self) -> None: super().setUp() - ignore = {'DATABASECHANGELOG', 'DATABASECHANGELOGLOCK', 'project', 'group'} + ignore = { + 'DATABASECHANGELOG', + 'DATABASECHANGELOGLOCK', + 'project', + 'group', + 'project-member', + } for table in TABLES_ORDERED_BY_FK_DEPS: if table in ignore: continue From f0326f57ec9f05edf91346bbc7ab18ad646857fd Mon Sep 17 00:00:00 2001 From: Dan Coates Date: Tue, 18 Jun 2024 16:39:28 +1000 Subject: [PATCH 17/29] fix merge issues --- db/python/layers/family.py | 2 +- db/python/layers/participant.py | 2 +- db/python/tables/participant.py | 2 +- db/python/tables/sample.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/db/python/layers/family.py b/db/python/layers/family.py index d3c1a7317..c3bbbf6cf 100644 --- a/db/python/layers/family.py +++ b/db/python/layers/family.py @@ -12,7 +12,7 @@ from db.python.tables.participant import ParticipantTable from db.python.tables.sample import SampleTable from db.python.utils import GenericFilter, NotFoundError -from models.models import PRIMARY_EXTERNAL_ORG, ProjectId +from models.models import PRIMARY_EXTERNAL_ORG from models.models.family import FamilyInternal, PedRow, PedRowInternal from models.models.participant import ParticipantUpsertInternal from models.models.project import ProjectId, ReadAccessRoles diff --git a/db/python/layers/participant.py b/db/python/layers/participant.py index 6fd331c36..39fb7a17d 100644 --- a/db/python/layers/participant.py +++ b/db/python/layers/participant.py @@ -21,7 +21,7 @@ NotFoundError, split_generic_terms, ) -from models.models import PRIMARY_EXTERNAL_ORG, ProjectId +from models.models import PRIMARY_EXTERNAL_ORG from models.models.family import PedRowInternal from models.models.participant import ParticipantInternal, ParticipantUpsertInternal from models.models.project import FullWriteAccessRoles, ProjectId, ReadAccessRoles diff --git a/db/python/tables/participant.py b/db/python/tables/participant.py index 19bb007c5..13f728633 100644 --- a/db/python/tables/participant.py +++ b/db/python/tables/participant.py @@ -122,7 +122,7 @@ async def create_participant( """ _eid_values = [ { - 'project': project or self.project, + 'project': project or self.project_id, 'pid': new_id, 'name': name.lower(), 'external_id': eid, diff --git a/db/python/tables/sample.py b/db/python/tables/sample.py index cb522454a..ba8dee1ab 100644 --- a/db/python/tables/sample.py +++ b/db/python/tables/sample.py @@ -190,7 +190,7 @@ async def insert_sample( """ _eid_values = [ { - 'project': project or self.project, + 'project': project or self.project_id, 'id': id_of_new_sample, 'name': name.lower(), 'external_id': eid, From 84b88b3e3ad72957de69aea6340ce80fec7b2487 Mon Sep 17 00:00:00 2001 From: Dan Coates Date: Tue, 18 Jun 2024 16:39:40 +1000 Subject: [PATCH 18/29] update docs --- docs/installation.md | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/docs/installation.md b/docs/installation.md index 50652b764..be1016c5a 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -288,7 +288,7 @@ export SM_DEV_DB_PORT=3306 # or 3307 You'll want to set the following environment variables (permanently) in your local development environment. -The `SM_ENVIRONMENT`, `SM_LOCALONLY_DEFAULTUSER` and `SM_ALLOWALLACCESS` environment variables allow access to a local metamist server without providing a bearer token. 
+The `SM_ENVIRONMENT` and `SM_LOCALONLY_DEFAULTUSER` environment variables allow access to a local metamist server without providing a bearer token. This will allow you to test the front-end components that access data. This happens automatically on the production instance through the Google identity-aware-proxy. @@ -296,12 +296,19 @@ This will allow you to test the front-end components that access data. This happ # ensures the SWAGGER page points to your local: (localhost:8000/docs) # and ensures if you use the PythonAPI, it also points to your local export SM_ENVIRONMENT=LOCAL -# skips permission checks in your local environment -export SM_ALLOWALLACCESS=true # uses your username as the "author" in requests export SM_LOCALONLY_DEFAULTUSER=$(whoami) ``` +To allow the sytem to be bootstrapped and create the initial project, you'll need to add yourself to the two admin groups that allow creating projects and updating project members: + +```sql +INSERT INTO group_member(group_id, member) +SELECT id, '' +FROM `group` WHERE name IN('project-creators', 'members-admin') + +``` + With those variables set, it is a good time to populate some test data if this is your first time running this server: ```bash @@ -335,7 +342,6 @@ The following `launch.json` is a good base to debug the web server in VS Code: "module": "api.server", "justMyCode": false, "env": { - "SM_ALLOWALLACCESS": "true", "SM_LOCALONLY_DEFAULTUSER": "-local", "SM_ENVIRONMENT": "local", "SM_DEV_DB_USER": "sm_api", From 93aefcf73fad3cc1e82e17843fe6621edc126ede Mon Sep 17 00:00:00 2001 From: Dan Coates Date: Tue, 25 Jun 2024 11:03:13 +1000 Subject: [PATCH 19/29] fix permission checks in project routes --- api/routes/project.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/api/routes/project.py b/api/routes/project.py index 0a809ecc0..1ee5b1f35 100644 --- a/api/routes/project.py +++ b/api/routes/project.py @@ -83,9 +83,6 @@ async def update_project( """Update a project by project name""" ptable = ProjectPermissionsTable(connection) - # Updating a project additionally requires the project creator permission - await ptable.check_project_creator_permissions(author=connection.author) - project = connection.project assert project return await ptable.update_project( @@ -97,7 +94,7 @@ async def update_project( async def delete_project_data( delete_project: bool = False, connection: Connection = get_project_db_connection( - {ProjectMemberRole.project_admin} + {ProjectMemberRole.project_admin, ProjectMemberRole.data_manager} ), ): """ @@ -116,7 +113,9 @@ async def delete_project_data( ) if not data_manager_deleting_test: # Otherwise, deleting a project additionally requires the project creator permission - await ptable.check_project_creator_permissions(author=connection.author) + connection.check_access_to_projects_for_ids( + [connection.project.id], allowed_roles={ProjectMemberRole.project_admin} + ) success = await ptable.delete_project_data( project_id=connection.project.id, @@ -139,8 +138,6 @@ async def update_project_members( """ ptable = ProjectPermissionsTable(connection) - await ptable.check_member_admin_permissions(author=connection.author) - for member in members: for role in member.roles: if role not in updatable_project_member_role_names: From 5cef94219458b6ff983700c4a409cf70a486ac7f Mon Sep 17 00:00:00 2001 From: Dan Coates Date: Tue, 25 Jun 2024 11:05:23 +1000 Subject: [PATCH 20/29] make query uppercase for consistency --- db/python/tables/project.py | 6 +++--- 1 file changed, 3 insertions(+), 3 
deletions(-) diff --git a/db/python/tables/project.py b/db/python/tables/project.py index 390a18cd4..b71b22136 100644 --- a/db/python/tables/project.py +++ b/db/python/tables/project.py @@ -283,9 +283,9 @@ async def set_project_members( # Get existing rows so that we can keep the existing audit log ids existing_rows = await self.connection.fetch_all( """ - select project_id, member, role, audit_log_id - from project_member - where project_id = :project_id + SELECT project_id, member, role, audit_log_id + FROM project_member + WHERE project_id = :project_id """, {'project_id': project.id}, ) From 76554e31abbd9fd096901304c6b9722652a82ef8 Mon Sep 17 00:00:00 2001 From: Dan Coates Date: Tue, 25 Jun 2024 11:13:02 +1000 Subject: [PATCH 21/29] use test environment for tests --- api/settings.py | 2 +- test/testbase.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/api/settings.py b/api/settings.py index 91d58c1e8..c16fae52b 100644 --- a/api/settings.py +++ b/api/settings.py @@ -59,7 +59,7 @@ def get_default_user() -> str | None: """Determine if a default user is available""" - if SM_ENVIRONMENT == 'local' and _DEFAULT_USER: + if SM_ENVIRONMENT in ('local', 'test') and _DEFAULT_USER: return _DEFAULT_USER return None diff --git a/test/testbase.py b/test/testbase.py index d03520a23..04eacffe9 100644 --- a/test/testbase.py +++ b/test/testbase.py @@ -95,6 +95,8 @@ async def setup(): Then you can destroy the database within tearDownClass as all tests have been completed. """ + # Set environment to test + os.environ['SM_ENVIRONMENT'] = 'test' logger = logging.getLogger() try: db = MySqlContainer('mariadb:11.2.2') From aab26312cd9d7cfc21a62a225fee0cb477580423 Mon Sep 17 00:00:00 2001 From: Dan Coates Date: Tue, 25 Jun 2024 11:36:43 +1000 Subject: [PATCH 22/29] make connection class variables protected --- db/python/connect.py | 34 +++++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/db/python/connect.py b/db/python/connect.py index 665d4937b..38e3f81d6 100644 --- a/db/python/connect.py +++ b/db/python/connect.py @@ -66,10 +66,10 @@ def __init__( ar_guid: str | None, meta: dict[str, str] | None = None, ): - self.connection: databases.Database = connection - self.project: Project | None = project - self.project_id_map = project_id_map - self.project_name_map = project_name_map + self.__connection: databases.Database = connection + self.__project: Project | None = project + self.__project_id_map = project_id_map + self.__project_name_map = project_name_map self.author: str = author self.on_behalf_of: str | None = on_behalf_of self.ar_guid: str | None = ar_guid @@ -78,6 +78,26 @@ def __init__( self._audit_log_id: int | None = None self._audit_log_lock = asyncio.Lock() + @property + def connection(self): + """Public getter for private project class variable""" + return self.__connection + + @property + def project(self): + """Public getter for private project class variable""" + return self.__project + + @property + def project_id_map(self): + """Public getter for private project_id_map class variable""" + return self.__project_id_map + + @property + def project_name_map(self): + """Public getter for private project_name_map class variable""" + return self.__project_name_map + @property def project_id(self): """Safely get the project id from the project model attached to the connection""" @@ -240,11 +260,11 @@ async def refresh_projects(self): project_id_map, project_name_map = await pt.get_projects_accessible_by_user( user=self.author ) - 
self.project_id_map = project_id_map - self.project_name_map = project_name_map + self.__project_id_map = project_id_map + self.__project_name_map = project_name_map if self.project_id: - self.project = self.project_id_map.get(self.project_id) + self.__project = self.project_id_map.get(self.project_id) class DatabaseConfiguration(abc.ABC): From 6a32cc11ab0501a7b42485df4ebb6c311901931b Mon Sep 17 00:00:00 2001 From: Dan Coates Date: Tue, 25 Jun 2024 11:50:43 +1000 Subject: [PATCH 23/29] add check for default user in generate data scripts --- docs/installation.md | 2 +- test/data/generate_data.py | 5 +++++ test/data/generate_seqr_project_data.py | 5 +++++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/docs/installation.md b/docs/installation.md index be1016c5a..b952e948e 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -300,7 +300,7 @@ export SM_ENVIRONMENT=LOCAL export SM_LOCALONLY_DEFAULTUSER=$(whoami) ``` -To allow the sytem to be bootstrapped and create the initial project, you'll need to add yourself to the two admin groups that allow creating projects and updating project members: +To allow the system to be bootstrapped and create the initial project, you'll need to add yourself to the two admin groups that allow creating projects and updating project members: ```sql INSERT INTO group_member(group_id, member) diff --git a/test/data/generate_data.py b/test/data/generate_data.py index d72b01891..ab96a5017 100755 --- a/test/data/generate_data.py +++ b/test/data/generate_data.py @@ -78,6 +78,11 @@ async def main(ped_path=default_ped_location, project='greek-myth'): name=project, dataset=project, create_test_project=False ) default_user = os.getenv('SM_LOCALONLY_DEFAULTUSER') + if not default_user: + print( + "SM_LOCALONLY_DEFAULTUSER env var is not set, please set it before generating data" + ) + exit(1) await papi.update_project_members_async( project=project, diff --git a/test/data/generate_seqr_project_data.py b/test/data/generate_seqr_project_data.py index bb299aaec..5069eb653 100644 --- a/test/data/generate_seqr_project_data.py +++ b/test/data/generate_seqr_project_data.py @@ -456,6 +456,11 @@ async def main(): ) default_user = os.getenv('SM_LOCALONLY_DEFAULTUSER') + if not default_user: + print( + "SM_LOCALONLY_DEFAULTUSER env var is not set, please set it before generating data" + ) + exit(1) await papi.update_project_members_async( project=project, From ed7599b7a5e884b5c8d3b63888e7bcd49df3117e Mon Sep 17 00:00:00 2001 From: Dan Coates Date: Tue, 25 Jun 2024 11:56:57 +1000 Subject: [PATCH 24/29] fix linting errors --- test/data/generate_data.py | 7 ++++--- test/data/generate_seqr_project_data.py | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/test/data/generate_data.py b/test/data/generate_data.py index ab96a5017..079d8a1cd 100755 --- a/test/data/generate_data.py +++ b/test/data/generate_data.py @@ -5,6 +5,7 @@ import datetime import os import random +import sys from pathlib import Path from pprint import pprint from uuid import uuid4 @@ -80,9 +81,9 @@ async def main(ped_path=default_ped_location, project='greek-myth'): default_user = os.getenv('SM_LOCALONLY_DEFAULTUSER') if not default_user: print( - "SM_LOCALONLY_DEFAULTUSER env var is not set, please set it before generating data" + 'SM_LOCALONLY_DEFAULTUSER env var is not set, please set it before generating data' ) - exit(1) + sys.exit(1) await papi.update_project_members_async( project=project, @@ -122,7 +123,7 @@ def generate_random_number_within_distribution(): )[0] 
samples = [] - sample_id_index = 1003 + sample_id_index = 10000 for participant_eid in participant_eids: pid = id_map[participant_eid] diff --git a/test/data/generate_seqr_project_data.py b/test/data/generate_seqr_project_data.py index 5069eb653..026c53e31 100644 --- a/test/data/generate_seqr_project_data.py +++ b/test/data/generate_seqr_project_data.py @@ -458,9 +458,9 @@ async def main(): default_user = os.getenv('SM_LOCALONLY_DEFAULTUSER') if not default_user: print( - "SM_LOCALONLY_DEFAULTUSER env var is not set, please set it before generating data" + 'SM_LOCALONLY_DEFAULTUSER env var is not set, please set it before generating data' ) - exit(1) + sys.exit(1) await papi.update_project_members_async( project=project, From 24d3341ba983da397fe7fc23095792fde790abf8 Mon Sep 17 00:00:00 2001 From: Dan Coates Date: Tue, 25 Jun 2024 12:31:00 +1000 Subject: [PATCH 25/29] simplify roles to allow better management from cpg-infrastructure --- api/routes/project.py | 17 +++-------------- db/project.xml | 2 +- models/models/project.py | 16 ---------------- 3 files changed, 4 insertions(+), 31 deletions(-) diff --git a/api/routes/project.py b/api/routes/project.py index 1ee5b1f35..c4fd03315 100644 --- a/api/routes/project.py +++ b/api/routes/project.py @@ -12,7 +12,7 @@ ProjectMemberRole, ProjectMemberUpdate, ReadAccessRoles, - updatable_project_member_role_names, + project_member_role_names, ) router = APIRouter(prefix='/project', tags=['project']) @@ -94,7 +94,7 @@ async def update_project( async def delete_project_data( delete_project: bool = False, connection: Connection = get_project_db_connection( - {ProjectMemberRole.project_admin, ProjectMemberRole.data_manager} + {ProjectMemberRole.project_admin} ), ): """ @@ -106,17 +106,6 @@ async def delete_project_data( assert connection.project - # Allow data manager role to delete test projects - data_manager_deleting_test = ( - connection.project.is_test - and connection.project.roles & {ProjectMemberRole.data_manager} - ) - if not data_manager_deleting_test: - # Otherwise, deleting a project additionally requires the project creator permission - connection.check_access_to_projects_for_ids( - [connection.project.id], allowed_roles={ProjectMemberRole.project_admin} - ) - success = await ptable.delete_project_data( project_id=connection.project.id, delete_project=delete_project, @@ -140,7 +129,7 @@ async def update_project_members( for member in members: for role in member.roles: - if role not in updatable_project_member_role_names: + if role not in project_member_role_names: raise HTTPException( 400, f'Role {role} is not valid for member {member.member}' ) diff --git a/db/project.xml b/db/project.xml index 0ac4ce512..6a520d8d7 100644 --- a/db/project.xml +++ b/db/project.xml @@ -1429,7 +1429,7 @@ references="project(id)" /> - + diff --git a/models/models/project.py b/models/models/project.py index 6f9cc184f..abb41f6d7 100644 --- a/models/models/project.py +++ b/models/models/project.py @@ -14,29 +14,22 @@ 'reader', 'contributor', 'writer', - 'data_manager', 'project_admin', 'project_member_admin', ], ) -AdminRoles = {ProjectMemberRole.project_admin, ProjectMemberRole.project_member_admin} # These roles have read access to a project ReadAccessRoles = { ProjectMemberRole.reader, ProjectMemberRole.contributor, ProjectMemberRole.writer, - ProjectMemberRole.data_manager, } # Only write has full write access FullWriteAccessRoles = {ProjectMemberRole.writer} - project_member_role_names = [r.name for r in ProjectMemberRole] 
-updatable_project_member_role_names = [ - r.name for r in ProjectMemberRole if r not in AdminRoles -] class Project(SMBase): @@ -49,15 +42,6 @@ class Project(SMBase): roles: set[ProjectMemberRole] """The roles that the current user has within the project""" - @property - def is_test(self): - """ - Checks whether this is a test project. Comparing to the dataset is safer than - just checking whether the name ends with -test, just in case we have a non-test - project that happens to end with -test - """ - return self.name == f'{self.dataset}-test' - @field_serializer('roles') def serialize_roles(self, roles: set[ProjectMemberRole]): """convert roles into a form that can be returned from the API""" From 482783135e96903ed24cf08ed5d0e96b6cfa3b8b Mon Sep 17 00:00:00 2001 From: Dan Coates Date: Tue, 25 Jun 2024 12:50:50 +1000 Subject: [PATCH 26/29] re-reorder auth checks --- api/utils/db.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/api/utils/db.py b/api/utils/db.py index 951f63c20..aa0cb38a0 100644 --- a/api/utils/db.py +++ b/api/utils/db.py @@ -60,12 +60,6 @@ def authenticate( If a token (OR Google IAP auth jwt) is provided, return the email, else raise an Exception """ - - if default_user := get_default_user(): - # this should only happen in LOCAL environments - logging.info(f'Using {default_user} as authenticated user') - return default_user - if x_goog_iap_jwt_assertion: # We have to PREFER the IAP's identity, otherwise you could have a case where # the JWT is forged, but IAP lets it through and authenticates, but then we take @@ -77,6 +71,11 @@ def authenticate( if token: return email_from_id_token(token.credentials) + if default_user := get_default_user(): + # this should only happen in LOCAL environments + logging.info(f'Using {default_user} as authenticated user') + return default_user + raise HTTPException(status_code=401, detail='Not authenticated :(') From fea03e83a0bcc02a6366410f2bc69dc3893c75f6 Mon Sep 17 00:00:00 2001 From: Dan Coates Date: Mon, 1 Jul 2024 12:40:24 +1000 Subject: [PATCH 27/29] merge cleanup --- api/routes/web.py | 18 +++++++----------- db/python/layers/participant.py | 15 ++++----------- db/python/layers/web.py | 2 +- db/python/tables/sample.py | 2 -- db/python/utils.py | 2 +- test/test_project_groups.py | 2 +- 6 files changed, 14 insertions(+), 27 deletions(-) diff --git a/api/routes/web.py b/api/routes/web.py index a2284d0d6..c601b49d9 100644 --- a/api/routes/web.py +++ b/api/routes/web.py @@ -49,10 +49,6 @@ class SearchResponseModel(SMBase): operation_id='getProjectSummary', ) async def get_project_summary( - request: Request, - grid_filter: list[SearchItem], - limit: int = 20, - token: Optional[int] = 0, connection: Connection = get_project_db_connection(ReadAccessRoles), ) -> ProjectSummary: """Creates a new sample, and returns the internal sample ID""" @@ -69,7 +65,7 @@ async def get_project_summary( operation_id='getProjectParticipantsFilterSchema', ) async def get_project_project_participants_filter_schema( - _=get_project_readonly_connection, + _=get_project_db_connection(ReadAccessRoles), ): """Get project summary (from query) with some limit""" return ProjectParticipantGridFilter.model_json_schema() @@ -84,15 +80,15 @@ async def get_project_participants_grid_with_limit( limit: int, query: ProjectParticipantGridFilter, skip: int = 0, - connection=get_project_readonly_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), ): """Get project summary (from query) with some limit""" - if not 
connection.project: + if not connection.project_id: raise ValueError('No project was detected through the authentication') wlayer = WebLayer(connection) - pfilter = query.to_internal(project=connection.project) + pfilter = query.to_internal(project=connection.project_id) participants, pcount = await asyncio.gather( wlayer.query_participants(pfilter, limit=limit, skip=skip), @@ -120,15 +116,15 @@ async def export_project_participants( export_type: ExportType, query: ProjectParticipantGridFilter, fields: ExportProjectParticipantFields | None = None, - connection=get_project_readonly_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), ): """Get project summary (from query) with some limit""" - if not connection.project: + if not connection.project_id: raise ValueError('No project was detected through the authentication') wlayer = WebLayer(connection) - pfilter = query.to_internal(project=connection.project) + pfilter = query.to_internal(project=connection.project_id) participants_internal = await wlayer.query_participants(pfilter, limit=None) participants = [p.to_external() for p in participants_internal] diff --git a/db/python/layers/participant.py b/db/python/layers/participant.py index c8c643c3e..b2754eafb 100644 --- a/db/python/layers/participant.py +++ b/db/python/layers/participant.py @@ -16,12 +16,7 @@ from db.python.tables.participant import ParticipantFilter, ParticipantTable from db.python.tables.participant_phenotype import ParticipantPhenotypeTable from db.python.tables.sample import SampleTable -from db.python.utils import ( - GenericFilter, - NoOpAenter, - NotFoundError, - split_generic_terms, -) +from db.python.utils import NoOpAenter, NotFoundError, split_generic_terms from models.models import PRIMARY_EXTERNAL_ORG from models.models.family import PedRowInternal from models.models.participant import ParticipantInternal, ParticipantUpsertInternal @@ -250,7 +245,6 @@ async def query( filter_: ParticipantFilter, limit: int | None = None, skip: int | None = None, - check_project_ids: bool = True, ) -> list[ParticipantInternal]: """ Query participants from the database, heavy lifting done by the filter @@ -262,10 +256,9 @@ async def query( if not participants: return [] - if check_project_ids: - await self.ptable.check_access_to_project_ids( - self.author, projects, readonly=True - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) return participants diff --git a/db/python/layers/web.py b/db/python/layers/web.py index 874484909..fb0f1ff2a 100644 --- a/db/python/layers/web.py +++ b/db/python/layers/web.py @@ -91,7 +91,7 @@ async def get_total_number_of_assays(self): FROM assay sq INNER JOIN sample s ON s.id = sq.sample_id WHERE s.project = :project""" - return await self.connection.fetch_val(_query, {'project': self.project}) + return await self.connection.fetch_val(_query, {'project': self.project_id}) def get_seqr_links_from_project(self, project: WebProject) -> dict[str, str]: """ diff --git a/db/python/tables/sample.py b/db/python/tables/sample.py index 03c176824..e4362e76e 100644 --- a/db/python/tables/sample.py +++ b/db/python/tables/sample.py @@ -270,8 +270,6 @@ async def insert_sample( } for name, eid in external_ids.items() if eid is not None - for name, eid in external_ids.items() - if eid is not None ] await self.connection.execute_many(_eid_query, _eid_values) diff --git a/db/python/utils.py b/db/python/utils.py index 996093fbf..70c3b49c9 100644 --- a/db/python/utils.py +++ 
b/db/python/utils.py @@ -2,7 +2,7 @@ import logging import os import re -from typing import TypeVar +from typing import Any, TypeVar T = TypeVar('T') X = TypeVar('X') diff --git a/test/test_project_groups.py b/test/test_project_groups.py index 133669039..e7db46e9e 100644 --- a/test/test_project_groups.py +++ b/test/test_project_groups.py @@ -1,4 +1,5 @@ import uuid +from test.testbase import DbIsolatedTest, run_as_sync from db.python.tables.project import ( GROUP_NAME_MEMBERS_ADMIN, @@ -12,7 +13,6 @@ ProjectMemberUpdate, ReadAccessRoles, ) -from test.testbase import DbIsolatedTest, run_as_sync class TestGroupAccess(DbIsolatedTest): From 4da577216430c7145ef69368cf3f9bbbdcfbaa56 Mon Sep 17 00:00:00 2001 From: Dan Coates Date: Thu, 4 Jul 2024 16:09:28 +1000 Subject: [PATCH 28/29] update seqr project listing to raise if user can't access all seqr projs --- api/routes/project.py | 17 +++++++++++++++-- db/python/tables/project.py | 9 +++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/api/routes/project.py b/api/routes/project.py index c4fd03315..0df175436 100644 --- a/api/routes/project.py +++ b/api/routes/project.py @@ -6,6 +6,7 @@ get_projectless_db_connection, ) from db.python.tables.project import ProjectPermissionsTable +from db.python.utils import Forbidden from models.models.project import ( FullWriteAccessRoles, Project, @@ -69,8 +70,20 @@ async def create_project( @router.get('/seqr/all', operation_id='getSeqrProjects') async def get_seqr_projects(connection: Connection = get_projectless_db_connection): """Get SM projects that should sync to seqr""" - projects = connection.all_projects() - return [p for p in projects if p.meta and p.meta.get('is_seqr')] + ptable = ProjectPermissionsTable(connection) + seqr_project_ids = await ptable.get_seqr_project_ids() + my_seqr_projects = [ + p for p in connection.all_projects() if p.id in seqr_project_ids + ] + + # Fail if user doesn't have access to all seqr projects. 
This endpoint is used + # for joint-calling where we would want to include all seqr projects and know if + # any are missing, so it is important to raise an error here rather than just + # excluding projects due to permission issues + if len(my_seqr_projects) != len(seqr_project_ids): + raise Forbidden('The current user does not have access to all seqr projects') + + return my_seqr_projects @router.post('/{project}/update', operation_id='updateProject') diff --git a/db/python/tables/project.py b/db/python/tables/project.py index b71b22136..74fa1c74f 100644 --- a/db/python/tables/project.py +++ b/db/python/tables/project.py @@ -120,6 +120,15 @@ async def get_projects_accessible_by_user( return project_id_map, project_name_map + async def get_seqr_project_ids(self) -> list[int]: + """ + Get all projects with meta.is_seqr = true + """ + rows = await self.connection.fetch_all( + "SELECT id FROM project WHERE JSON_VALUE(meta, '$.is_seqr') = 1" + ) + return [r['id'] for r in rows] + async def check_if_member_in_group_by_name(self, group_name: str, member: str): """Check if a user exists in the group""" From fbc2897002829ca4c7f542ec855d42beaf8fddb6 Mon Sep 17 00:00:00 2001 From: Dan Coates Date: Thu, 4 Jul 2024 16:20:46 +1000 Subject: [PATCH 29/29] =?UTF-8?q?Bump=20version:=207.1.1=20=E2=86=92=207.2?= =?UTF-8?q?.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .bumpversion.cfg | 2 +- api/server.py | 2 +- deploy/python/version.txt | 2 +- setup.py | 2 +- web/package.json | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 4eda41cd7..655a25f06 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 7.1.1 +current_version = 7.2.0 commit = True tag = False parse = (?P\d+)\.(?P\d+)\.(?P[A-z0-9-]+) diff --git a/api/server.py b/api/server.py index aa1c7b827..bc2116a27 100644 --- a/api/server.py +++ b/api/server.py @@ -19,7 +19,7 @@ from db.python.utils import get_logger # This tag is automatically updated by bump2version -_VERSION = '7.1.1' +_VERSION = '7.2.0' logger = get_logger() diff --git a/deploy/python/version.txt b/deploy/python/version.txt index 21c8c7b46..0ee843cc6 100644 --- a/deploy/python/version.txt +++ b/deploy/python/version.txt @@ -1 +1 @@ -7.1.1 +7.2.0 diff --git a/setup.py b/setup.py index c2078672e..64203f264 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ setup( name=PKG, # This tag is automatically updated by bump2version - version='7.1.1', + version='7.2.0', description='Python API for interacting with the Sample API system', long_description=readme, long_description_content_type='text/markdown', diff --git a/web/package.json b/web/package.json index 88f97a346..8451098af 100644 --- a/web/package.json +++ b/web/package.json @@ -1,6 +1,6 @@ { "name": "metamist", - "version": "7.1.1", + "version": "7.2.0", "private": true, "dependencies": { "@apollo/client": "^3.7.3",