diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 4eda41cd7..655a25f06 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 7.1.1 +current_version = 7.2.0 commit = True tag = False parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>[A-z0-9-]+) diff --git a/api/graphql/filters.py b/api/graphql/filters.py index d00a2b1b7..8799b94ac 100644 --- a/api/graphql/filters.py +++ b/api/graphql/filters.py @@ -23,11 +23,11 @@ class GraphQLFilter(Generic[T]): icontains: T | None = None startswith: T | None = None - def all_values(self): + def all_values(self) -> list[T]: """ Get all values used anywhere in a filter, useful for getting values to map later """ - v = [] + v: list[T] = [] if self.eq: v.append(self.eq) if self.in_: diff --git a/api/graphql/loaders.py b/api/graphql/loaders.py index 3c0bd2996..c6c1fed2e 100644 --- a/api/graphql/loaders.py +++ b/api/graphql/loaders.py @@ -4,12 +4,14 @@ import dataclasses import enum from collections import defaultdict -from typing import Any +from typing import Any, TypedDict from fastapi import Request from strawberry.dataloader import DataLoader -from api.utils import get_projectless_db_connection, group_by +from api.utils import group_by +from api.utils.db import get_projectless_db_connection +from db.python.connect import Connection from db.python.filters import GenericFilter, get_hashable_value from db.python.layers import ( AnalysisLayer, @@ -24,7 +26,6 @@ from db.python.tables.assay import AssayFilter from db.python.tables.family import FamilyFilter from db.python.tables.participant import ParticipantFilter -from db.python.tables.project import ProjectPermissionsTable from db.python.tables.sample import SampleFilter from db.python.tables.sequencing_group import SequencingGroupFilter from db.python.utils import NotFoundError @@ -79,14 +80,14 @@ class LoaderKeys(enum.Enum): SEQUENCING_GROUPS_FOR_ANALYSIS = 'sequencing_groups_for_analysis' -loaders = {} +loaders: dict[LoaderKeys, Any] = {} -def 
connected_data_loader(id_: LoaderKeys, cache=True): +def connected_data_loader(id_: LoaderKeys, cache: bool = True): """Provide connection to a data loader""" def connected_data_loader_caller(fn): - def inner(connection): + def inner(connection: Connection): async def wrapped(*args, **kwargs): return await fn(*args, **kwargs, connection=connection) @@ -110,7 +111,7 @@ def connected_data_loader_with_params( """ def connected_data_loader_caller(fn): - def inner(connection): + def inner(connection: Connection): async def wrapped(query: list[dict[str, Any]]) -> list[Any]: by_key: dict[tuple, Any] = {} @@ -159,7 +160,7 @@ async def wrapped(query: list[dict[str, Any]]) -> list[Any]: @connected_data_loader(LoaderKeys.AUDIT_LOGS_BY_IDS) async def load_audit_logs_by_ids( - audit_log_ids: list[int], connection + audit_log_ids: list[int], connection: Connection ) -> list[AuditLogInternal | None]: """ DataLoader: get_audit_logs_by_ids @@ -172,7 +173,7 @@ async def load_audit_logs_by_ids( @connected_data_loader(LoaderKeys.AUDIT_LOGS_BY_ANALYSIS_IDS) async def load_audit_logs_by_analysis_ids( - analysis_ids: list[int], connection + analysis_ids: list[int], connection: Connection ) -> list[list[AuditLogInternal]]: """ DataLoader: get_audit_logs_by_analysis_ids @@ -184,7 +185,7 @@ async def load_audit_logs_by_analysis_ids( @connected_data_loader_with_params(LoaderKeys.ASSAYS_FOR_SAMPLES, default_factory=list) async def load_assays_by_samples( - connection, ids, filter: AssayFilter + connection: Connection, ids, filter: AssayFilter ) -> dict[int, list[AssayInternal]]: """ DataLoader: get_assays_for_sample_ids @@ -200,7 +201,7 @@ async def load_assays_by_samples( @connected_data_loader(LoaderKeys.ASSAYS_FOR_SEQUENCING_GROUPS) async def load_assays_by_sequencing_groups( - sequencing_group_ids: list[int], connection + sequencing_group_ids: list[int], connection: Connection ) -> list[list[AssayInternal]]: """ Get all assays belong to the sequencing groups @@ -209,7 +210,7 @@ async def 
load_assays_by_sequencing_groups( # group by all last fields, in case we add more assays = await assaylayer.get_assays_for_sequencing_group_ids( - sequencing_group_ids=sequencing_group_ids, check_project_ids=False + sequencing_group_ids=sequencing_group_ids ) return [assays.get(sg, []) for sg in sequencing_group_ids] @@ -219,7 +220,7 @@ async def load_assays_by_sequencing_groups( LoaderKeys.SAMPLES_FOR_PARTICIPANTS, default_factory=list ) async def load_samples_for_participant_ids( - ids: list[int], filter: SampleFilter, connection + ids: list[int], filter: SampleFilter, connection: Connection ) -> dict[int, list[SampleInternal]]: """ DataLoader: get_samples_for_participant_ids @@ -232,7 +233,7 @@ async def load_samples_for_participant_ids( @connected_data_loader(LoaderKeys.SEQUENCING_GROUPS_FOR_IDS) async def load_sequencing_groups_for_ids( - sequencing_group_ids: list[int], connection + sequencing_group_ids: list[int], connection: Connection ) -> list[SequencingGroupInternal]: """ DataLoader: get_sequencing_groups_by_ids @@ -249,7 +250,7 @@ async def load_sequencing_groups_for_ids( LoaderKeys.SEQUENCING_GROUPS_FOR_SAMPLES, default_factory=list ) async def load_sequencing_groups_for_samples( - connection, ids: list[int], filter: SequencingGroupFilter + connection: Connection, ids: list[int], filter: SequencingGroupFilter ) -> dict[int, list[SequencingGroupInternal]]: """ Has format [(sample_id: int, sequencing_type?: string)] @@ -270,7 +271,7 @@ async def load_sequencing_groups_for_samples( @connected_data_loader(LoaderKeys.SAMPLES_FOR_IDS) async def load_samples_for_ids( - sample_ids: list[int], connection + sample_ids: list[int], connection: Connection ) -> list[SampleInternal]: """ DataLoader: get_samples_for_ids @@ -286,7 +287,7 @@ async def load_samples_for_ids( LoaderKeys.SAMPLES_FOR_PROJECTS, default_factory=list ) async def load_samples_for_projects( - connection, ids: list[ProjectId], filter: SampleFilter + connection: Connection, ids: list[ProjectId], 
filter: SampleFilter ): """ DataLoader: get_samples_for_project_ids @@ -300,7 +301,7 @@ async def load_samples_for_projects( @connected_data_loader(LoaderKeys.PARTICIPANTS_FOR_IDS) async def load_participants_for_ids( - participant_ids: list[int], connection + participant_ids: list[int], connection: Connection ) -> list[ParticipantInternal]: """ DataLoader: get_participants_by_ids @@ -318,7 +319,7 @@ async def load_participants_for_ids( @connected_data_loader(LoaderKeys.SEQUENCING_GROUPS_FOR_ANALYSIS) async def load_sequencing_groups_for_analysis_ids( - analysis_ids: list[int], connection + analysis_ids: list[int], connection: Connection ) -> list[list[SequencingGroupInternal]]: """ DataLoader: get_samples_for_analysis_ids @@ -333,7 +334,7 @@ async def load_sequencing_groups_for_analysis_ids( LoaderKeys.SEQUENCING_GROUPS_FOR_PROJECTS, default_factory=list ) async def load_sequencing_groups_for_project_ids( - ids: list[int], filter: SequencingGroupFilter, connection + ids: list[int], filter: SequencingGroupFilter, connection: Connection ) -> dict[int, list[SequencingGroupInternal]]: """ DataLoader: get_sequencing_groups_for_project_ids @@ -347,39 +348,33 @@ async def load_sequencing_groups_for_project_ids( @connected_data_loader(LoaderKeys.PROJECTS_FOR_IDS) -async def load_projects_for_ids(project_ids: list[int], connection) -> list[Project]: +async def load_projects_for_ids( + project_ids: list[int], connection: Connection +) -> list[Project]: """ Get projects by IDs """ - pttable = ProjectPermissionsTable(connection) - projects = await pttable.get_and_check_access_to_projects_for_ids( - user=connection.author, project_ids=project_ids, readonly=True - ) - - p_by_id = {p.id: p for p in projects} - projects = [p_by_id.get(p) for p in project_ids] + projects = [connection.project_id_map.get(p) for p in project_ids] return [p for p in projects if p is not None] @connected_data_loader(LoaderKeys.FAMILIES_FOR_PARTICIPANTS) async def load_families_for_participants( - 
participant_ids: list[int], connection + participant_ids: list[int], connection: Connection ) -> list[list[FamilyInternal]]: """ Get families of participants, noting a participant can be in multiple families """ flayer = FamilyLayer(connection) - fam_map = await flayer.get_families_by_participants( - participant_ids=participant_ids, check_project_ids=False - ) + fam_map = await flayer.get_families_by_participants(participant_ids=participant_ids) return [fam_map.get(p, []) for p in participant_ids] @connected_data_loader(LoaderKeys.PARTICIPANTS_FOR_FAMILIES) async def load_participants_for_families( - family_ids: list[int], connection + family_ids: list[int], connection: Connection ) -> list[list[ParticipantInternal]]: """Get all participants in a family, doesn't include affected statuses""" player = ParticipantLayer(connection) @@ -391,7 +386,7 @@ async def load_participants_for_families( LoaderKeys.PARTICIPANTS_FOR_PROJECTS, default_factory=list ) async def load_participants_for_projects( - ids: list[ProjectId], filter_: ParticipantFilter, connection + ids: list[ProjectId], filter_: ParticipantFilter, connection: Connection ) -> dict[ProjectId, list[ParticipantInternal]]: """ Get all participants in a project @@ -411,7 +406,7 @@ async def load_participants_for_projects( async def load_analyses_for_sequencing_groups( ids: list[int], filter_: AnalysisFilter, - connection, + connection: Connection, ) -> dict[int, list[AnalysisInternal]]: """ Type: (sequencing_group_id: int, status?: AnalysisStatus, type?: str) @@ -429,7 +424,7 @@ async def load_analyses_for_sequencing_groups( @connected_data_loader(LoaderKeys.PHENOTYPES_FOR_PARTICIPANTS) async def load_phenotypes_for_participants( - participant_ids: list[int], connection + participant_ids: list[int], connection: Connection ) -> list[dict]: """ Data loader for phenotypes for participants @@ -443,7 +438,7 @@ async def load_phenotypes_for_participants( @connected_data_loader(LoaderKeys.FAMILIES_FOR_IDS) async def 
load_families_for_ids( - family_ids: list[int], connection + family_ids: list[int], connection: Connection ) -> list[FamilyInternal]: """ DataLoader: get_families_for_ids @@ -456,7 +451,7 @@ async def load_families_for_ids( @connected_data_loader(LoaderKeys.FAMILY_PARTICIPANTS_FOR_FAMILIES) async def load_family_participants_for_families( - family_ids: list[int], connection + family_ids: list[int], connection: Connection ) -> list[list[PedRowInternal]]: """ DataLoader: get_family_participants_for_families @@ -469,7 +464,7 @@ async def load_family_participants_for_families( @connected_data_loader(LoaderKeys.FAMILY_PARTICIPANTS_FOR_PARTICIPANTS) async def load_family_participants_for_participants( - participant_ids: list[int], connection + participant_ids: list[int], connection: Connection ) -> list[list[PedRowInternal]]: """data loader for family participants for participants @@ -490,12 +485,21 @@ async def load_family_participants_for_participants( return [fp_map.get(pid, []) for pid in participant_ids] +class GraphQLContext(TypedDict): + """Basic dict type for GraphQL context to be passed to resolvers""" + + loaders: dict[LoaderKeys, Any] + connection: Connection + + async def get_context( - request: Request, connection=get_projectless_db_connection -): # pylint: disable=unused-argument + request: Request, # pylint: disable=unused-argument + connection: Connection = get_projectless_db_connection, +) -> GraphQLContext: """Get loaders / cache context for strawberyy GraphQL""" mapped_loaders = {k: fn(connection) for k, fn in loaders.items()} + return { 'connection': connection, - **mapped_loaders, + 'loaders': mapped_loaders, } diff --git a/api/graphql/schema.py b/api/graphql/schema.py index 6f57a6574..2e9df5fae 100644 --- a/api/graphql/schema.py +++ b/api/graphql/schema.py @@ -18,7 +18,7 @@ GraphQLMetaFilter, graphql_meta_filter_to_internal_filter, ) -from api.graphql.loaders import LoaderKeys, get_context +from api.graphql.loaders import GraphQLContext, LoaderKeys, 
get_context from db.python import enum_tables from db.python.filters import GenericFilter from db.python.layers import ( @@ -37,7 +37,6 @@ from db.python.tables.cohort import CohortFilter, CohortTemplateFilter from db.python.tables.family import FamilyFilter from db.python.tables.participant import ParticipantFilter -from db.python.tables.project import ProjectPermissionsTable from db.python.tables.sample import SampleFilter from db.python.tables.sequencing_group import SequencingGroupFilter from models.enums import AnalysisStatus @@ -57,7 +56,7 @@ from models.models.analysis_runner import AnalysisRunnerInternal from models.models.family import PedRowInternal from models.models.ourdna import OurDNADashboard, OurDNALostSample -from models.models.project import ProjectId +from models.models.project import FullWriteAccessRoles, ProjectId, ReadAccessRoles from models.models.sample import sample_id_transform_to_raw from models.utils.cohort_id_format import cohort_id_format, cohort_id_transform_to_raw from models.utils.cohort_template_id_format import ( @@ -76,7 +75,7 @@ continue def create_function(_enum): - async def m(info: Info) -> list[str]: + async def m(info: Info[GraphQLContext, 'Query']) -> list[str]: return await _enum(info.context['connection']).get() m.__name__ = _enum.get_enum_name() @@ -139,20 +138,18 @@ def from_internal(internal: CohortInternal) -> 'GraphQLCohort': @strawberry.field() async def template( - self, info: Info, root: 'GraphQLCohort' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLCohort' ) -> 'GraphQLCohortTemplate': connection = info.context['connection'] template = await CohortLayer(connection).get_template_by_cohort_id( cohort_id_transform_to_raw(root.id) ) - ptable = ProjectPermissionsTable(connection) - projects = await ptable.get_and_check_access_to_projects_for_ids( - user=connection.author, + projects = connection.get_and_check_access_to_projects_for_ids( project_ids=( template.criteria.projects if 
template.criteria.projects else [] ), - readonly=True, + allowed_roles=ReadAccessRoles, ) project_names = [p.name for p in projects if p.name] @@ -162,7 +159,7 @@ async def template( @strawberry.field() async def sequencing_groups( - self, info: Info, root: 'GraphQLCohort' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLCohort' ) -> list['GraphQLSequencingGroup']: connection = info.context['connection'] cohort_layer = CohortLayer(connection) @@ -176,10 +173,9 @@ async def sequencing_groups( @strawberry.field() async def analyses( - self, info: Info, root: 'GraphQLCohort' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLCohort' ) -> list['GraphQLAnalysis']: connection = info.context['connection'] - connection.project = root.project_id internal_analysis = await AnalysisLayer(connection).query( AnalysisFilter( cohort_id=GenericFilter(in_=[cohort_id_transform_to_raw(root.id)]), @@ -188,8 +184,10 @@ async def analyses( return [GraphQLAnalysis.from_internal(a) for a in internal_analysis] @strawberry.field() - async def project(self, info: Info, root: 'GraphQLCohort') -> 'GraphQLProject': - loader = info.context[LoaderKeys.PROJECTS_FOR_IDS] + async def project( + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLCohort' + ) -> 'GraphQLProject': + loader = info.context['loaders'][LoaderKeys.PROJECTS_FOR_IDS] project = await loader.load(root.project_id) return GraphQLProject.from_internal(project) @@ -221,9 +219,9 @@ def from_internal( @strawberry.field() async def project( - self, info: Info, root: 'GraphQLCohortTemplate' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLCohortTemplate' ) -> 'GraphQLProject': - loader = info.context[LoaderKeys.PROJECTS_FOR_IDS] + loader = info.context['loaders'][LoaderKeys.PROJECTS_FOR_IDS] project = await loader.load(root.project_id) return GraphQLProject.from_internal(project) @@ -249,7 +247,7 @@ def from_internal(internal: Project) -> 'GraphQLProject': @strawberry.field() async def analysis_runner( 
self, - info: Info, + info: Info[GraphQLContext, 'Query'], root: 'Project', ar_guid: GraphQLFilter[str] | None = None, author: GraphQLFilter[str] | None = None, @@ -285,7 +283,7 @@ async def ourdna_dashboard( @strawberry.field() async def pedigree( self, - info: Info, + info: Info[GraphQLContext, 'Query'], root: 'Project', internal_family_ids: list[int] | None = None, replace_with_participant_external_ids: bool = True, @@ -313,7 +311,7 @@ async def pedigree( @strawberry.field() async def families( self, - info: Info, + info: Info[GraphQLContext, 'Query'], root: 'GraphQLProject', id: GraphQLFilter[int] | None = None, external_id: GraphQLFilter[str] | None = None, @@ -333,7 +331,7 @@ async def families( @strawberry.field() async def participants( self, - info: Info, + info: Info[GraphQLContext, 'Query'], root: 'GraphQLProject', id: GraphQLFilter[int] | None = None, external_id: GraphQLFilter[str] | None = None, @@ -342,7 +340,7 @@ async def participants( reported_gender: GraphQLFilter[str] | None = None, karyotype: GraphQLFilter[str] | None = None, ) -> list['GraphQLParticipant']: - loader = info.context[LoaderKeys.PARTICIPANTS_FOR_PROJECTS] + loader = info.context['loaders'][LoaderKeys.PARTICIPANTS_FOR_PROJECTS] participants = await loader.load( { 'id': root.id, @@ -370,14 +368,14 @@ async def participants( @strawberry.field() async def samples( self, - info: Info, + info: Info[GraphQLContext, 'Query'], root: 'GraphQLProject', type: GraphQLFilter[str] | None = None, external_id: GraphQLFilter[str] | None = None, id: GraphQLFilter[str] | None = None, meta: GraphQLMetaFilter | None = None, ) -> list['GraphQLSample']: - loader = info.context[LoaderKeys.SAMPLES_FOR_PROJECTS] + loader = info.context['loaders'][LoaderKeys.SAMPLES_FOR_PROJECTS] filter_ = SampleFilter( type=type.to_internal_filter() if type else None, external_id=external_id.to_internal_filter() if external_id else None, @@ -390,7 +388,7 @@ async def samples( @strawberry.field() async def sequencing_groups( 
self, - info: Info, + info: Info[GraphQLContext, 'Query'], root: 'GraphQLProject', id: GraphQLFilter[str] | None = None, external_id: GraphQLFilter[str] | None = None, @@ -399,7 +397,7 @@ async def sequencing_groups( platform: GraphQLFilter[str] | None = None, active_only: GraphQLFilter[bool] | None = None, ) -> list['GraphQLSequencingGroup']: - loader = info.context[LoaderKeys.SEQUENCING_GROUPS_FOR_PROJECTS] + loader = info.context['loaders'][LoaderKeys.SEQUENCING_GROUPS_FOR_PROJECTS] filter_ = SequencingGroupFilter( id=( id.to_internal_filter_mapped(sequencing_group_id_transform_to_raw) @@ -422,7 +420,7 @@ async def sequencing_groups( @strawberry.field() async def analyses( self, - info: Info, + info: Info[GraphQLContext, 'Query'], root: 'Project', type: GraphQLFilter[str] | None = None, status: GraphQLFilter[GraphQLAnalysisStatus] | None = None, @@ -432,7 +430,6 @@ async def analyses( ids: GraphQLFilter[int] | None = None, ) -> list['GraphQLAnalysis']: connection = info.context['connection'] - connection.project = root.id internal_analysis = await AnalysisLayer(connection).query( AnalysisFilter( id=ids.to_internal_filter() if ids else None, @@ -457,7 +454,7 @@ async def analyses( @strawberry.field() async def cohorts( self, - info: Info, + info: Info[GraphQLContext, 'Query'], root: 'Project', id: GraphQLFilter[str] | None = None, name: GraphQLFilter[str] | None = None, @@ -466,7 +463,6 @@ async def cohorts( timestamp: GraphQLFilter[datetime.datetime] | None = None, ) -> list['GraphQLCohort']: connection = info.context['connection'] - connection.project = root.id c_filter = CohortFilter( id=id.to_internal_filter_mapped(cohort_id_transform_to_raw) if id else None, @@ -541,23 +537,25 @@ def from_internal(internal: AnalysisInternal) -> 'GraphQLAnalysis': @strawberry.field async def sequencing_groups( - self, info: Info, root: 'GraphQLAnalysis' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLAnalysis' ) -> list['GraphQLSequencingGroup']: - loader = 
info.context[LoaderKeys.SEQUENCING_GROUPS_FOR_ANALYSIS] + loader = info.context['loaders'][LoaderKeys.SEQUENCING_GROUPS_FOR_ANALYSIS] sgs = await loader.load(root.id) return [GraphQLSequencingGroup.from_internal(sg) for sg in sgs] @strawberry.field - async def project(self, info: Info, root: 'GraphQLAnalysis') -> GraphQLProject: - loader = info.context[LoaderKeys.PROJECTS_FOR_IDS] + async def project( + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLAnalysis' + ) -> GraphQLProject: + loader = info.context['loaders'][LoaderKeys.PROJECTS_FOR_IDS] project = await loader.load(root.project_id) return GraphQLProject.from_internal(project) @strawberry.field async def audit_logs( - self, info: Info, root: 'GraphQLAnalysis' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLAnalysis' ) -> list[GraphQLAuditLog]: - loader = info.context[LoaderKeys.AUDIT_LOGS_BY_ANALYSIS_IDS] + loader = info.context['loaders'][LoaderKeys.AUDIT_LOGS_BY_ANALYSIS_IDS] audit_logs = await loader.load(root.id) return [GraphQLAuditLog.from_internal(audit_log) for audit_log in audit_logs] @@ -586,25 +584,27 @@ def from_internal(internal: FamilyInternal) -> 'GraphQLFamily': ) @strawberry.field - async def project(self, info: Info, root: 'GraphQLFamily') -> GraphQLProject: - loader = info.context[LoaderKeys.PROJECTS_FOR_IDS] + async def project( + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLFamily' + ) -> GraphQLProject: + loader = info.context['loaders'][LoaderKeys.PROJECTS_FOR_IDS] project = await loader.load(root.project_id) return GraphQLProject.from_internal(project) @strawberry.field async def participants( - self, info: Info, root: 'GraphQLFamily' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLFamily' ) -> list['GraphQLParticipant']: - participants = await info.context[LoaderKeys.PARTICIPANTS_FOR_FAMILIES].load( - root.id - ) + participants = await info.context['loaders'][ + LoaderKeys.PARTICIPANTS_FOR_FAMILIES + ].load(root.id) return 
[GraphQLParticipant.from_internal(p) for p in participants] @strawberry.field async def family_participants( - self, info: Info, root: 'GraphQLFamily' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLFamily' ) -> list['GraphQLFamilyParticipant']: - family_participants = await info.context[ + family_participants = await info.context['loaders'][ LoaderKeys.FAMILY_PARTICIPANTS_FOR_FAMILIES ].load(root.id) return [ @@ -627,17 +627,17 @@ class GraphQLFamilyParticipant: @strawberry.field async def participant( - self, info: Info, root: 'GraphQLFamilyParticipant' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLFamilyParticipant' ) -> 'GraphQLParticipant': - loader = info.context[LoaderKeys.PARTICIPANTS_FOR_IDS] + loader = info.context['loaders'][LoaderKeys.PARTICIPANTS_FOR_IDS] participant = await loader.load(root.participant_id) return GraphQLParticipant.from_internal(participant) @strawberry.field async def family( - self, info: Info, root: 'GraphQLFamilyParticipant' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLFamilyParticipant' ) -> GraphQLFamily: - loader = info.context[LoaderKeys.FAMILIES_FOR_IDS] + loader = info.context['loaders'][LoaderKeys.FAMILIES_FOR_IDS] family = await loader.load(root.family_id) return GraphQLFamily.from_internal(family) @@ -682,7 +682,7 @@ def from_internal(internal: ParticipantInternal) -> 'GraphQLParticipant': @strawberry.field async def samples( self, - info: Info, + info: Info[GraphQLContext, 'Query'], root: 'GraphQLParticipant', type: GraphQLFilter[str] | None = None, meta: GraphQLMetaFilter | None = None, @@ -695,28 +695,32 @@ async def samples( ) q = {'id': root.id, 'filter': filter_} - samples = await info.context[LoaderKeys.SAMPLES_FOR_PARTICIPANTS].load(q) + samples = await info.context['loaders'][ + LoaderKeys.SAMPLES_FOR_PARTICIPANTS + ].load(q) return [GraphQLSample.from_internal(s) for s in samples] @strawberry.field async def phenotypes( - self, info: Info, root: 'GraphQLParticipant' + self, 
info: Info[GraphQLContext, 'Query'], root: 'GraphQLParticipant' ) -> strawberry.scalars.JSON: - loader = info.context[LoaderKeys.PHENOTYPES_FOR_PARTICIPANTS] + loader = info.context['loaders'][LoaderKeys.PHENOTYPES_FOR_PARTICIPANTS] return await loader.load(root.id) @strawberry.field async def families( - self, info: Info, root: 'GraphQLParticipant' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLParticipant' ) -> list[GraphQLFamily]: - fams = await info.context[LoaderKeys.FAMILIES_FOR_PARTICIPANTS].load(root.id) + fams = await info.context['loaders'][LoaderKeys.FAMILIES_FOR_PARTICIPANTS].load( + root.id + ) return [GraphQLFamily.from_internal(f) for f in fams] @strawberry.field async def family_participants( - self, info: Info, root: 'GraphQLParticipant' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLParticipant' ) -> list[GraphQLFamilyParticipant]: - family_participants = await info.context[ + family_participants = await info.context['loaders'][ LoaderKeys.FAMILY_PARTICIPANTS_FOR_PARTICIPANTS ].load(root.id) return [ @@ -724,18 +728,20 @@ async def family_participants( ] @strawberry.field - async def project(self, info: Info, root: 'GraphQLParticipant') -> GraphQLProject: - loader = info.context[LoaderKeys.PROJECTS_FOR_IDS] + async def project( + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLParticipant' + ) -> GraphQLProject: + loader = info.context['loaders'][LoaderKeys.PROJECTS_FOR_IDS] project = await loader.load(root.project_id) return GraphQLProject.from_internal(project) @strawberry.field async def audit_log( - self, info: Info, root: 'GraphQLParticipant' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLParticipant' ) -> GraphQLAuditLog | None: if root.audit_log_id is None: return None - loader = info.context[LoaderKeys.AUDIT_LOGS_BY_IDS] + loader = info.context['loaders'][LoaderKeys.AUDIT_LOGS_BY_IDS] audit_log = await loader.load(root.audit_log_id) return GraphQLAuditLog.from_internal(audit_log) @@ -771,23 
+777,27 @@ def from_internal(sample: SampleInternal) -> 'GraphQLSample': @strawberry.field async def participant( - self, info: Info, root: 'GraphQLSample' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLSample' ) -> GraphQLParticipant | None: if root.participant_id is None: return None - loader_participants_for_ids = info.context[LoaderKeys.PARTICIPANTS_FOR_IDS] + loader_participants_for_ids = info.context['loaders'][ + LoaderKeys.PARTICIPANTS_FOR_IDS + ] participant = await loader_participants_for_ids.load(root.participant_id) return GraphQLParticipant.from_internal(participant) @strawberry.field async def assays( self, - info: Info, + info: Info[GraphQLContext, 'Query'], root: 'GraphQLSample', type: GraphQLFilter[str] | None = None, meta: GraphQLMetaFilter | None = None, ) -> list['GraphQLAssay']: - loader_assays_for_sample_ids = info.context[LoaderKeys.ASSAYS_FOR_SAMPLES] + loader_assays_for_sample_ids = info.context['loaders'][ + LoaderKeys.ASSAYS_FOR_SAMPLES + ] filter_ = AssayFilter( type=type.to_internal_filter() if type else None, meta=meta, @@ -798,14 +808,18 @@ async def assays( return [GraphQLAssay.from_internal(assay) for assay in assays] @strawberry.field - async def project(self, info: Info, root: 'GraphQLSample') -> GraphQLProject: - project = await info.context[LoaderKeys.PROJECTS_FOR_IDS].load(root.project_id) + async def project( + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLSample' + ) -> GraphQLProject: + project = await info.context['loaders'][LoaderKeys.PROJECTS_FOR_IDS].load( + root.project_id + ) return GraphQLProject.from_internal(project) @strawberry.field async def sequencing_groups( self, - info: Info, + info: Info[GraphQLContext, 'Query'], root: 'GraphQLSample', id: GraphQLFilter[str] | None = None, type: GraphQLFilter[str] | None = None, @@ -814,7 +828,7 @@ async def sequencing_groups( meta: GraphQLMetaFilter | None = None, active_only: GraphQLFilter[bool] | None = None, ) -> list['GraphQLSequencingGroup']: - 
loader = info.context[LoaderKeys.SEQUENCING_GROUPS_FOR_SAMPLES] + loader = info.context['loaders'][LoaderKeys.SEQUENCING_GROUPS_FOR_SAMPLES] _filter = SequencingGroupFilter( id=( @@ -870,15 +884,17 @@ def from_internal(internal: SequencingGroupInternal) -> 'GraphQLSequencingGroup' ) @strawberry.field - async def sample(self, info: Info, root: 'GraphQLSequencingGroup') -> GraphQLSample: - loader = info.context[LoaderKeys.SAMPLES_FOR_IDS] + async def sample( + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLSequencingGroup' + ) -> GraphQLSample: + loader = info.context['loaders'][LoaderKeys.SAMPLES_FOR_IDS] sample = await loader.load(root.sample_id) return GraphQLSample.from_internal(sample) @strawberry.field async def analyses( self, - info: Info, + info: Info[GraphQLContext, 'Query'], root: 'GraphQLSequencingGroup', status: GraphQLFilter[GraphQLAnalysisStatus] | None = None, type: GraphQLFilter[str] | None = None, @@ -887,14 +903,14 @@ async def analyses( project: GraphQLFilter[str] | None = None, ) -> list[GraphQLAnalysis]: connection = info.context['connection'] - loader = info.context[LoaderKeys.ANALYSES_FOR_SEQUENCING_GROUPS] + loader = info.context['loaders'][LoaderKeys.ANALYSES_FOR_SEQUENCING_GROUPS] _project_filter: GenericFilter[ProjectId] | None = None if project: - ptable = ProjectPermissionsTable(connection) - project_ids = project.all_values() - projects = await ptable.get_and_check_access_to_projects_for_names( - user=connection.author, project_names=project_ids, readonly=True + project_names = project.all_values() + projects = connection.get_and_check_access_to_projects_for_names( + project_names=project_names, + allowed_roles=ReadAccessRoles, ) project_id_map: dict[str, int] = { p.name: p.id for p in projects if p.name and p.id @@ -923,9 +939,9 @@ async def analyses( @strawberry.field async def assays( - self, info: Info, root: 'GraphQLSequencingGroup' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLSequencingGroup' ) -> 
list['GraphQLAssay']: - loader = info.context[LoaderKeys.ASSAYS_FOR_SEQUENCING_GROUPS] + loader = info.context['loaders'][LoaderKeys.ASSAYS_FOR_SEQUENCING_GROUPS] assays = await loader.load(root.internal_id) return [GraphQLAssay.from_internal(assay) for assay in assays] @@ -956,8 +972,10 @@ def from_internal(internal: AssayInternal) -> 'GraphQLAssay': ) @strawberry.field - async def sample(self, info: Info, root: 'GraphQLAssay') -> GraphQLSample: - loader = info.context[LoaderKeys.SAMPLES_FOR_IDS] + async def sample( + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLAssay' + ) -> GraphQLSample: + loader = info.context['loaders'][LoaderKeys.SAMPLES_FOR_IDS] sample = await loader.load(root.sample_id) return GraphQLSample.from_internal(sample) @@ -1011,9 +1029,9 @@ def from_internal(internal: AnalysisRunnerInternal) -> 'GraphQLAnalysisRunner': @strawberry.field async def project( - self, info: Info, root: 'GraphQLAnalysisRunner' + self, info: Info[GraphQLContext, 'Query'], root: 'GraphQLAnalysisRunner' ) -> GraphQLProject: - loader = info.context[LoaderKeys.PROJECTS_FOR_IDS] + loader = info.context['loaders'][LoaderKeys.PROJECTS_FOR_IDS] project = await loader.load(root.internal_project) return GraphQLProject.from_internal(project) @@ -1023,26 +1041,25 @@ class Query: # entry point to graphql. 
"""GraphQL Queries""" @strawberry.field() - def enum(self, info: Info) -> GraphQLEnum: # type: ignore + def enum(self, info: Info[GraphQLContext, 'Query']) -> GraphQLEnum: # type: ignore return GraphQLEnum() @strawberry.field() async def cohort_templates( self, - info: Info, + info: Info[GraphQLContext, 'Query'], id: GraphQLFilter[str] | None = None, project: GraphQLFilter[str] | None = None, ) -> list[GraphQLCohortTemplate]: connection = info.context['connection'] cohort_layer = CohortLayer(connection) - ptable = ProjectPermissionsTable(connection) project_name_map: dict[str, int] = {} project_filter = None if project: project_names = project.all_values() - projects = await ptable.get_and_check_access_to_projects_for_names( - user=connection.author, project_names=project_names, readonly=True + projects = connection.get_and_check_access_to_projects_for_names( + project_names=project_names, allowed_roles=ReadAccessRoles ) project_name_map = {p.name: p.id for p in projects} project_filter = project.to_internal_filter_mapped( @@ -1063,10 +1080,9 @@ async def cohort_templates( external_templates = [] for template in cohort_templates: - template_projects = await ptable.get_and_check_access_to_projects_for_ids( - user=connection.author, + template_projects = connection.get_and_check_access_to_projects_for_ids( project_ids=template.criteria.projects or [], - readonly=True, + allowed_roles=ReadAccessRoles, ) template_project_names = [p.name for p in template_projects if p.name] external_templates.append( @@ -1078,7 +1094,7 @@ async def cohort_templates( @strawberry.field() async def cohorts( self, - info: Info, + info: Info[GraphQLContext, 'Query'], id: GraphQLFilter[str] | None = None, project: GraphQLFilter[str] | None = None, name: GraphQLFilter[str] | None = None, @@ -1088,13 +1104,12 @@ async def cohorts( connection = info.context['connection'] cohort_layer = CohortLayer(connection) - ptable = ProjectPermissionsTable(connection) project_name_map: dict[str, int] = {} 
project_filter = None if project: project_names = project.all_values() - projects = await ptable.get_and_check_access_to_projects_for_names( - user=connection.author, project_names=project_names, readonly=True + projects = connection.get_and_check_access_to_projects_for_names( + project_names=project_names, allowed_roles=ReadAccessRoles ) project_name_map = {p.name: p.id for p in projects} project_filter = project.to_internal_filter_mapped( @@ -1119,18 +1134,19 @@ async def cohorts( return [GraphQLCohort.from_internal(cohort) for cohort in cohorts] @strawberry.field() - async def project(self, info: Info, name: str) -> GraphQLProject: + async def project( + self, info: Info[GraphQLContext, 'Query'], name: str + ) -> GraphQLProject: connection = info.context['connection'] - ptable = ProjectPermissionsTable(connection) - project = await ptable.get_and_check_access_to_project_for_name( - user=connection.author, project_name=name, readonly=True + projects = connection.get_and_check_access_to_projects_for_names( + project_names=[name], allowed_roles=ReadAccessRoles ) - return GraphQLProject.from_internal(project) + return GraphQLProject.from_internal(next(p for p in projects)) @strawberry.field async def sample( self, - info: Info, + info: Info[GraphQLContext, 'Query'], id: GraphQLFilter[str] | None = None, project: GraphQLFilter[str] | None = None, type: GraphQLFilter[str] | None = None, @@ -1140,7 +1156,6 @@ async def sample( active: GraphQLFilter[bool] | None = None, ) -> list[GraphQLSample]: connection = info.context['connection'] - ptable = ProjectPermissionsTable(connection) slayer = SampleLayer(connection) if not id and not project: @@ -1152,8 +1167,9 @@ async def sample( project_name_map: dict[str, int] = {} if project: project_names = project.all_values() - projects = await ptable.get_and_check_access_to_projects_for_names( - user=connection.author, project_names=project_names, readonly=True + projects = connection.get_and_check_access_to_projects_for_names( + 
project_names=project_names, + allowed_roles=ReadAccessRoles, ) project_name_map = {p.name: p.id for p in projects if p.name and p.id} @@ -1180,7 +1196,7 @@ async def sample( @strawberry.field async def sequencing_groups( self, - info: Info, + info: Info[GraphQLContext, 'Query'], id: GraphQLFilter[str] | None = None, project: GraphQLFilter[str] | None = None, sample_id: GraphQLFilter[str] | None = None, @@ -1195,7 +1211,6 @@ async def sequencing_groups( ) -> list[GraphQLSequencingGroup]: connection = info.context['connection'] sglayer = SequencingGroupLayer(connection) - ptable = ProjectPermissionsTable(connection) if not (project or sample_id or id): raise ValueError('Must filter by project, sample or id') @@ -1204,8 +1219,9 @@ async def sequencing_groups( if project: project_names = project.all_values() - projects = await ptable.get_and_check_access_to_projects_for_names( - user=connection.author, project_names=project_names, readonly=True + projects = connection.get_and_check_access_to_projects_for_names( + project_names=project_names, + allowed_roles=ReadAccessRoles, ) project_id_map = {p.name: p.id for p in projects if p.name and p.id} _project_filter = project.to_internal_filter_mapped( @@ -1247,36 +1263,41 @@ async def sequencing_groups( return [GraphQLSequencingGroup.from_internal(sg) for sg in sgs] @strawberry.field - async def assay(self, info: Info, id: int) -> GraphQLAssay: + async def assay(self, info: Info[GraphQLContext, 'Query'], id: int) -> GraphQLAssay: connection = info.context['connection'] slayer = AssayLayer(connection) assay = await slayer.get_assay_by_id(id) return GraphQLAssay.from_internal(assay) @strawberry.field - async def participant(self, info: Info, id: int) -> GraphQLParticipant: - loader = info.context[LoaderKeys.PARTICIPANTS_FOR_IDS] + async def participant( + self, info: Info[GraphQLContext, 'Query'], id: int + ) -> GraphQLParticipant: + loader = info.context['loaders'][LoaderKeys.PARTICIPANTS_FOR_IDS] return 
GraphQLParticipant.from_internal(await loader.load(id)) @strawberry.field() - async def family(self, info: Info, family_id: int) -> GraphQLFamily: + async def family( + self, info: Info[GraphQLContext, 'Query'], family_id: int + ) -> GraphQLFamily: connection = info.context['connection'] family = await FamilyLayer(connection).get_family_by_internal_id(family_id) return GraphQLFamily.from_internal(family) @strawberry.field - async def my_projects(self, info: Info) -> list[GraphQLProject]: + async def my_projects( + self, info: Info[GraphQLContext, 'Query'] + ) -> list[GraphQLProject]: connection = info.context['connection'] - ptable = ProjectPermissionsTable(connection) - projects = await ptable.get_projects_accessible_by_user( - connection.author, readonly=True + projects = connection.projects_with_role( + ReadAccessRoles.union(FullWriteAccessRoles) ) return [GraphQLProject.from_internal(p) for p in projects] @strawberry.field async def analysis_runner( self, - info: Info, + info: Info[GraphQLContext, 'Query'], ar_guid: str, ) -> GraphQLAnalysisRunner: if not ar_guid: diff --git a/api/routes/analysis.py b/api/routes/analysis.py index 86b37d775..303520111 100644 --- a/api/routes/analysis.py +++ b/api/routes/analysis.py @@ -11,8 +11,7 @@ from api.utils.dates import parse_date_only_string from api.utils.db import ( Connection, - get_project_readonly_connection, - get_project_write_connection, + get_project_db_connection, get_projectless_db_connection, ) from api.utils.export import ExportType @@ -21,10 +20,10 @@ from db.python.layers.analysis_runner import AnalysisRunnerLayer from db.python.tables.analysis import AnalysisFilter from db.python.tables.analysis_runner import AnalysisRunnerFilter -from db.python.tables.project import ProjectPermissionsTable from models.enums import AnalysisStatus from models.models.analysis import Analysis, ProportionalDateTemporalMethod from models.models.analysis_runner import AnalysisRunner +from models.models.project import 
FullWriteAccessRoles, ReadAccessRoles from models.utils.sequencing_group_id_format import ( sequencing_group_id_format, sequencing_group_id_format_list, @@ -94,7 +93,8 @@ def to_filter(self, project_id_map: dict[str, int]) -> AnalysisFilter: @router.put('/{project}/', operation_id='createAnalysis', response_model=int) async def create_analysis( - analysis: Analysis, connection: Connection = get_project_write_connection + analysis: Analysis, + connection: Connection = get_project_db_connection(FullWriteAccessRoles), ) -> int: """Create a new analysis""" @@ -134,14 +134,14 @@ async def update_analysis( ) async def get_all_sample_ids_without_analysis_type( analysis_type: str, - connection: Connection = get_project_readonly_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), ): """get_all_sample_ids_without_analysis_type""" atable = AnalysisLayer(connection) - assert connection.project + assert connection.project_id sequencing_group_ids = ( await atable.get_all_sequencing_group_ids_without_analysis_type( - connection.project, analysis_type + connection.project_id, analysis_type ) ) return { @@ -154,12 +154,12 @@ async def get_all_sample_ids_without_analysis_type( operation_id='getIncompleteAnalyses', ) async def get_incomplete_analyses( - connection: Connection = get_project_readonly_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), ): """Get analyses with status queued or in-progress""" atable = AnalysisLayer(connection) - assert connection.project - results = await atable.get_incomplete_analyses(project=connection.project) + assert connection.project_id + results = await atable.get_incomplete_analyses(project=connection.project_id) return [r.to_external() for r in results] @@ -169,13 +169,13 @@ async def get_incomplete_analyses( ) async def get_latest_complete_analysis_for_type( analysis_type: str, - connection: Connection = get_project_readonly_connection, + connection: Connection = 
get_project_db_connection(ReadAccessRoles), ): """Get (SINGLE) latest complete analysis for some analysis type""" alayer = AnalysisLayer(connection) - assert connection.project + assert connection.project_id analysis = await alayer.get_latest_complete_analysis_for_type( - project=connection.project, analysis_type=analysis_type + project=connection.project_id, analysis_type=analysis_type ) return analysis.to_external() @@ -187,16 +187,16 @@ async def get_latest_complete_analysis_for_type( async def get_latest_complete_analysis_for_type_post( analysis_type: str, meta: dict[str, Any] = Body(..., embed=True), # type: ignore[assignment] - connection: Connection = get_project_readonly_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), ): """ Get SINGLE latest complete analysis for some analysis type (you can specify meta attributes in this route) """ alayer = AnalysisLayer(connection) - assert connection.project + assert connection.project_id analysis = await alayer.get_latest_complete_analysis_for_type( - project=connection.project, + project=connection.project_id, analysis_type=analysis_type, meta=meta, ) @@ -227,9 +227,8 @@ async def query_analyses( if not query.projects: raise ValueError('Must specify "projects"') - pt = ProjectPermissionsTable(connection) - projects = await pt.get_and_check_access_to_projects_for_names( - user=connection.author, project_names=query.projects, readonly=True + projects = connection.get_and_check_access_to_projects_for_names( + query.projects, ReadAccessRoles ) project_name_map = {p.name: p.id for p in projects} atable = AnalysisLayer(connection) @@ -251,9 +250,8 @@ async def get_analysis_runner_log( raise ValueError('Must specify "project_names"') arlayer = AnalysisRunnerLayer(connection) - pt = ProjectPermissionsTable(connection) - projects = await pt.get_and_check_access_to_projects_for_names( - connection.author, project_names, readonly=True + projects = 
connection.get_and_check_access_to_projects_for_names( + project_names, allowed_roles=ReadAccessRoles ) project_ids = [p.id for p in projects if p.id] project_map = {p.id: p.name for p in projects if p.id and p.name} @@ -275,7 +273,7 @@ async def get_analysis_runner_log( async def get_sample_reads_map( export_type: ExportType = ExportType.JSON, sequencing_types: list[str] = Query(None), # type: ignore - connection: Connection = get_project_readonly_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), ): """ Get map of ExternalSampleId pathToCram InternalSeqGroupID for seqr @@ -290,9 +288,9 @@ async def get_sample_reads_map( """ at = AnalysisLayer(connection) - assert connection.project + assert connection.project_id objs = await at.get_sample_cram_path_map_for_seqr( - project=connection.project, sequencing_types=sequencing_types + project=connection.project_id, sequencing_types=sequencing_types ) for r in objs: @@ -308,7 +306,7 @@ async def get_sample_reads_map( writer = csv.writer(output, delimiter=export_type.get_delimiter()) writer.writerows(rows) - basefn = f'{connection.project}-seqr-igv-paths-{date.today().isoformat()}' + basefn = f'{connection.project_id}-seqr-igv-paths-{date.today().isoformat()}' return StreamingResponse( iter([output.getvalue()]), @@ -342,10 +340,10 @@ async def get_proportionate_map( } } """ - pt = ProjectPermissionsTable(connection) - project_ids = await pt.get_project_ids_from_names_and_user( - connection.author, projects, readonly=True + project_list = connection.get_and_check_access_to_projects_for_names( + projects, allowed_roles=ReadAccessRoles ) + project_ids = [p.id for p in project_list] start_date = parse_date_only_string(start) if start else None end_date = parse_date_only_string(end) if end else None diff --git a/api/routes/analysis_runner.py b/api/routes/analysis_runner.py index 0ac0e3925..666e5b4d5 100644 --- a/api/routes/analysis_runner.py +++ b/api/routes/analysis_runner.py @@ -2,15 +2,12 @@ 
from fastapi import APIRouter -from api.utils.db import ( - Connection, - get_project_readonly_connection, - get_project_write_connection, -) +from api.utils.db import Connection, get_project_db_connection from db.python.filters import GenericFilter from db.python.layers.analysis_runner import AnalysisRunnerLayer from db.python.tables.analysis_runner import AnalysisRunnerFilter from models.models.analysis_runner import AnalysisRunner, AnalysisRunnerInternal +from models.models.project import FullWriteAccessRoles, ReadAccessRoles router = APIRouter(prefix='/analysis-runner', tags=['analysis-runner']) @@ -32,13 +29,13 @@ async def create_analysis_runner_log( # pylint: disable=too-many-arguments output_path: str, hail_version: str | None = None, cwd: str | None = None, - connection: Connection = get_project_write_connection, + connection: Connection = get_project_db_connection(FullWriteAccessRoles), ) -> str: """Create a new analysis runner log""" alayer = AnalysisRunnerLayer(connection) - if not connection.project: + if not connection.project_id: raise ValueError('Project not set') analysis_id = await alayer.insert_analysis_runner_entry( @@ -58,7 +55,7 @@ async def create_analysis_runner_log( # pylint: disable=too-many-arguments batch_url=batch_url, submitting_user=submitting_user, meta=meta, - project=connection.project, + project=connection.project_id, audit_log_id=None, output_path=output_path, ) @@ -75,13 +72,13 @@ async def get_analysis_runner_logs( repository: str | None = None, access_level: str | None = None, environment: str | None = None, - connection: Connection = get_project_readonly_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), ) -> list[AnalysisRunner]: """Get analysis runner logs""" atable = AnalysisRunnerLayer(connection) - if not connection.project: + if not connection.project_id: raise ValueError('Project not set') filter_ = AnalysisRunnerFilter( @@ -90,9 +87,9 @@ async def get_analysis_runner_logs( 
repository=GenericFilter(eq=repository), access_level=GenericFilter(eq=access_level), environment=GenericFilter(eq=environment), - project=GenericFilter(eq=connection.project), + project=GenericFilter(eq=connection.project_id), ) logs = await atable.query(filter_) - return [log.to_external({connection.project: project}) for log in logs] + return [log.to_external({connection.project_id: project}) for log in logs] diff --git a/api/routes/assay.py b/api/routes/assay.py index 6680d7a34..e5092a8b8 100644 --- a/api/routes/assay.py +++ b/api/routes/assay.py @@ -2,14 +2,17 @@ from fastapi import APIRouter -from api.utils import get_project_readonly_connection -from api.utils.db import Connection, get_projectless_db_connection +from api.utils.db import ( + Connection, + get_project_db_connection, + get_projectless_db_connection, +) from db.python.filters import GenericFilter from db.python.layers.assay import AssayLayer from db.python.tables.assay import AssayFilter -from db.python.tables.project import ProjectPermissionsTable from models.base import SMBase from models.models.assay import AssayUpsert +from models.models.project import ReadAccessRoles from models.utils.sample_id_format import sample_id_transform_to_raw_list router = APIRouter(prefix='/assay', tags=['assay']) @@ -53,7 +56,8 @@ async def get_assay_by_id( '/{project}/external_id/{external_id}/details', operation_id='getAssayByExternalId' ) async def get_assay_by_external_id( - external_id: str, connection=get_project_readonly_connection + external_id: str, + connection: Connection = get_project_db_connection(ReadAccessRoles), ): """Get an assay by ONE of its external identifiers""" assay_layer = AssayLayer(connection) @@ -80,13 +84,13 @@ async def get_assays_by_criteria( ): """Get assays by criteria""" assay_layer = AssayLayer(connection) - pt = ProjectPermissionsTable(connection) pids: list[int] | None = None if criteria.projects: - pids = await pt.get_project_ids_from_names_and_user( - connection.author, 
criteria.projects, readonly=True + project_list = connection.get_and_check_access_to_projects_for_names( + criteria.projects, allowed_roles=ReadAccessRoles ) + pids = [p.id for p in project_list] unwrapped_sample_ids: list[int] | None = None if criteria.sample_ids: diff --git a/api/routes/cohort.py b/api/routes/cohort.py index 0ea555a2f..4a5568926 100644 --- a/api/routes/cohort.py +++ b/api/routes/cohort.py @@ -1,10 +1,9 @@ from fastapi import APIRouter -from api.utils.db import Connection, HACK_get_project_contributor_connection +from api.utils.db import Connection, get_project_db_connection from db.python.layers.cohort import CohortLayer -from db.python.tables.project import ProjectPermissionsTable from models.models.cohort import CohortBody, CohortCriteria, CohortTemplate, NewCohort -from models.models.project import ProjectId +from models.models.project import ProjectId, ProjectMemberRole, ReadAccessRoles from models.utils.cohort_template_id_format import ( cohort_template_id_format, cohort_template_id_transform_to_raw, @@ -17,7 +16,9 @@ async def create_cohort_from_criteria( cohort_spec: CohortBody, cohort_criteria: CohortCriteria | None = None, - connection: Connection = HACK_get_project_contributor_connection, + connection: Connection = get_project_db_connection( + {ProjectMemberRole.writer, ProjectMemberRole.contributor} + ), dry_run: bool = False, ) -> NewCohort: """ @@ -35,24 +36,21 @@ async def create_cohort_from_criteria( internal_project_ids: list[ProjectId] = [] - if cohort_criteria: - if cohort_criteria.projects: - pt = ProjectPermissionsTable(connection) - projects = await pt.get_and_check_access_to_projects_for_names( - user=connection.author, - project_names=cohort_criteria.projects, - readonly=True, - ) - if projects: - internal_project_ids = [p.id for p in projects if p.id] + if cohort_criteria and cohort_criteria.projects: + projects = connection.get_and_check_access_to_projects_for_names( + project_names=cohort_criteria.projects, 
allowed_roles=ReadAccessRoles + ) + + internal_project_ids = [p.id for p in projects if p.id] template_id_raw = ( cohort_template_id_transform_to_raw(cohort_spec.template_id) if cohort_spec.template_id else None ) + assert connection.project_id cohort_output = await cohort_layer.create_cohort_from_criteria( - project_to_write=connection.project, + project_to_write=connection.project_id, description=cohort_spec.description, cohort_name=cohort_spec.name, dry_run=dry_run, @@ -70,7 +68,9 @@ async def create_cohort_from_criteria( @router.post('/{project}/cohort_template', operation_id='createCohortTemplate') async def create_cohort_template( template: CohortTemplate, - connection: Connection = HACK_get_project_contributor_connection, + connection: Connection = get_project_db_connection( + {ProjectMemberRole.writer, ProjectMemberRole.contributor} + ), ) -> str: """ Create a cohort template with the given name and sample/sequencing group IDs. @@ -83,20 +83,19 @@ async def create_cohort_template( criteria_project_ids: list[ProjectId] = [] if template.criteria.projects: - pt = ProjectPermissionsTable(connection) - projects_for_criteria = await pt.get_and_check_access_to_projects_for_names( - user=connection.author, + projects_for_criteria = connection.get_and_check_access_to_projects_for_names( project_names=template.criteria.projects, - readonly=False, + allowed_roles=ReadAccessRoles, ) - if projects_for_criteria: - criteria_project_ids = [p.id for p in projects_for_criteria if p.id] + criteria_project_ids = [p.id for p in projects_for_criteria if p.id] + assert connection.project_id cohort_raw_id = await cohort_layer.create_cohort_template( cohort_template=template.to_internal( - criteria_projects=criteria_project_ids, template_project=connection.project + criteria_projects=criteria_project_ids, + template_project=connection.project_id, ), - project=connection.project, + project=connection.project_id, ) return cohort_template_id_format(cohort_raw_id) diff --git 
a/api/routes/family.py b/api/routes/family.py index 0003b23f8..6aafda365 100644 --- a/api/routes/family.py +++ b/api/routes/family.py @@ -9,11 +9,10 @@ from pydantic import BaseModel from starlette.responses import StreamingResponse -from api.utils import get_projectless_db_connection from api.utils.db import ( Connection, - get_project_readonly_connection, - get_project_write_connection, + get_project_db_connection, + get_projectless_db_connection, ) from api.utils.export import ExportType from api.utils.extensions import guess_delimiter_by_upload_file_obj @@ -21,6 +20,7 @@ from db.python.layers.family import FamilyLayer, PedRow from db.python.tables.family import FamilyFilter from models.models.family import Family +from models.models.project import FullWriteAccessRoles, ReadAccessRoles from models.utils.sample_id_format import sample_id_transform_to_raw_list router = APIRouter(prefix='/family', tags=['family']) @@ -41,7 +41,7 @@ async def import_pedigree( has_header: bool = False, create_missing_participants: bool = False, perform_sex_check: bool = True, - connection: Connection = get_project_write_connection, + connection: Connection = get_project_db_connection(FullWriteAccessRoles), ): """Import a pedigree""" delimiter = guess_delimiter_by_upload_file_obj(file) @@ -75,7 +75,7 @@ async def get_pedigree( replace_with_family_external_ids: bool = True, include_header: bool = True, empty_participant_value: Optional[str] = None, - connection: Connection = get_project_readonly_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), include_participants_not_in_families: bool = False, ): """ @@ -87,9 +87,9 @@ async def get_pedigree( """ family_layer = FamilyLayer(connection) - assert connection.project + assert connection.project_id pedigree_dicts = await family_layer.get_pedigree( - project=connection.project, + project=connection.project_id, family_ids=internal_family_ids, 
replace_with_participant_external_ids=replace_with_participant_external_ids, replace_with_family_external_ids=replace_with_family_external_ids, @@ -120,7 +120,7 @@ async def get_pedigree( ] writer.writerows(pedigree_rows) - basefn = f'{connection.project}-{date.today().isoformat()}' + basefn = f'{connection.project_id}-{date.today().isoformat()}' if internal_family_ids: basefn += '-'.join(str(fm) for fm in internal_family_ids) @@ -142,7 +142,7 @@ async def get_pedigree( async def get_families( participant_ids: Optional[List[int]] = Query(None), sample_ids: Optional[List[str]] = Query(None), - connection: Connection = get_project_readonly_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), ) -> List[Family]: """Get families for some project""" family_layer = FamilyLayer(connection) @@ -182,7 +182,7 @@ async def import_families( file: UploadFile = File(...), has_header: bool = True, delimiter: str | None = None, - connection: Connection = get_project_write_connection, + connection: Connection = get_project_db_connection(FullWriteAccessRoles), ): """Import a family csv""" delimiter = guess_delimiter_by_upload_file_obj(file, default_delimiter=delimiter) diff --git a/api/routes/imports.py b/api/routes/imports.py index 7f8f84bf2..66dc0531e 100644 --- a/api/routes/imports.py +++ b/api/routes/imports.py @@ -4,12 +4,13 @@ from fastapi import APIRouter, File, UploadFile -from api.utils.db import Connection, get_project_write_connection +from api.utils.db import Connection, get_project_db_connection from api.utils.extensions import guess_delimiter_by_upload_file_obj from db.python.layers.participant import ( ExtraParticipantImporterHandler, ParticipantLayer, ) +from models.models.project import FullWriteAccessRoles router = APIRouter(prefix='/import', tags=['import']) @@ -23,7 +24,7 @@ async def import_individual_metadata_manifest( file: UploadFile = File(...), delimiter: Optional[str] = None, extra_participants_method: 
ExtraParticipantImporterHandler = ExtraParticipantImporterHandler.FAIL, - connection: Connection = get_project_write_connection, + connection: Connection = get_project_db_connection(FullWriteAccessRoles), ): """ Import individual metadata manifest diff --git a/api/routes/participant.py b/api/routes/participant.py index d2b5e23b5..f6a9ddc22 100644 --- a/api/routes/participant.py +++ b/api/routes/participant.py @@ -6,16 +6,16 @@ from fastapi.params import Query from starlette.responses import JSONResponse, StreamingResponse -from api.utils import get_projectless_db_connection from api.utils.db import ( Connection, - get_project_readonly_connection, - get_project_write_connection, + get_project_db_connection, + get_projectless_db_connection, ) from api.utils.export import ExportType from db.python.layers.participant import ParticipantLayer from models.base import SMBase from models.models.participant import ParticipantUpsert +from models.models.project import FullWriteAccessRoles, ReadAccessRoles from models.models.sequencing_group import sequencing_group_id_format router = APIRouter(prefix='/participant', tags=['participant']) @@ -25,7 +25,7 @@ '/{project}/fill-in-missing-participants', operation_id='fillInMissingParticipants' ) async def fill_in_missing_participants( - connection: Connection = get_project_write_connection, + connection: Connection = get_project_db_connection(FullWriteAccessRoles), ): """ Create a corresponding participant (if required) @@ -47,13 +47,13 @@ async def get_individual_metadata_template_for_seqr( external_participant_ids: list[str] | None = Query(default=None), # type: ignore[assignment] # pylint: disable=invalid-name replace_with_participant_external_ids: bool = True, - connection: Connection = get_project_readonly_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), ): """Get individual metadata template for SEQR as a CSV""" participant_layer = ParticipantLayer(connection) - assert connection.project + 
assert connection.project_id resp = await participant_layer.get_seqr_individual_template( - project=connection.project, + project=connection.project_id, external_participant_ids=external_participant_ids, replace_with_participant_external_ids=replace_with_participant_external_ids, ) @@ -90,14 +90,14 @@ async def get_individual_metadata_template_for_seqr( async def get_id_map_by_external_ids( external_participant_ids: list[str], allow_missing: bool = False, - connection: Connection = get_project_readonly_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), ): """Get ID map of participants, by external_id""" player = ParticipantLayer(connection) return await player.get_id_map_by_external_ids( external_participant_ids, allow_missing=allow_missing, - project=connection.project, + project=connection.project_id, ) @@ -121,7 +121,7 @@ async def get_external_participant_id_to_sequencing_group_id( sequencing_type: str = None, export_type: ExportType = ExportType.JSON, flip_columns: bool = False, - connection: Connection = get_project_readonly_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), ): """ Get csv / tsv export of external_participant_id to sequencing_group_id @@ -136,10 +136,10 @@ async def get_external_participant_id_to_sequencing_group_id( :param flip_columns: Set to True when exporting for seqr """ player = ParticipantLayer(connection) - # this wants project ID (connection.project) - assert connection.project + # this wants project ID (connection.project_id) + assert connection.project_id m = await player.get_external_participant_id_to_internal_sequencing_group_id_map( - project=connection.project, sequencing_type=sequencing_type + project=connection.project_id, sequencing_type=sequencing_type ) rows = [[pid, sequencing_group_id_format(sgid)] for pid, sgid in m] @@ -189,7 +189,7 @@ async def update_participant( ) async def upsert_participants( participants: list[ParticipantUpsert], - connection: 
Connection = get_project_write_connection, + connection: Connection = get_project_db_connection(FullWriteAccessRoles), ): """ Upserts a list of participants with samples and sequences @@ -214,12 +214,13 @@ class QueryParticipantCriteria(SMBase): ) async def get_participants( criteria: QueryParticipantCriteria, - connection: Connection = get_project_readonly_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), ): """Get participants, default ALL participants in project""" player = ParticipantLayer(connection) + assert connection.project_id participants = await player.get_participants( - project=connection.project, + project=connection.project_id, external_participant_ids=( criteria.external_participant_ids if criteria else None ), diff --git a/api/routes/project.py b/api/routes/project.py index 559938485..0df175436 100644 --- a/api/routes/project.py +++ b/api/routes/project.py @@ -1,29 +1,39 @@ -from typing import List +from fastapi import APIRouter, HTTPException -from fastapi import APIRouter - -from api.utils.db import Connection, get_projectless_db_connection +from api.utils.db import ( + Connection, + get_project_db_connection, + get_projectless_db_connection, +) from db.python.tables.project import ProjectPermissionsTable -from models.models.project import Project +from db.python.utils import Forbidden +from models.models.project import ( + FullWriteAccessRoles, + Project, + ProjectMemberRole, + ProjectMemberUpdate, + ReadAccessRoles, + project_member_role_names, +) router = APIRouter(prefix='/project', tags=['project']) -@router.get('/all', operation_id='getAllProjects', response_model=List[Project]) -async def get_all_projects(connection=get_projectless_db_connection): +@router.get('/all', operation_id='getAllProjects', response_model=list[Project]) +async def get_all_projects(connection: Connection = get_projectless_db_connection): """Get list of projects""" - ptable = ProjectPermissionsTable(connection) - return await 
ptable.get_all_projects(author=connection.author) + return connection.all_projects() -@router.get('/', operation_id='getMyProjects', response_model=List[str]) -async def get_my_projects(connection=get_projectless_db_connection): +@router.get('/', operation_id='getMyProjects', response_model=list[str]) +async def get_my_projects(connection: Connection = get_projectless_db_connection): """Get projects I have access to""" - ptable = ProjectPermissionsTable(connection) - projects = await ptable.get_projects_accessible_by_user( - author=connection.author, readonly=True - ) - return [p.name for p in projects] + return [ + p.name + for p in connection.projects_with_role( + ReadAccessRoles.union(FullWriteAccessRoles) + ) + ] @router.put('/', operation_id='createProject') @@ -37,6 +47,10 @@ async def create_project( Create a new project """ ptable = ProjectPermissionsTable(connection) + + # Creating a project requires the project creator permission + await ptable.check_project_creator_permissions(author=connection.author) + pid = await ptable.create_project( project_name=name, dataset_name=dataset, @@ -57,27 +71,44 @@ async def create_project( async def get_seqr_projects(connection: Connection = get_projectless_db_connection): """Get SM projects that should sync to seqr""" ptable = ProjectPermissionsTable(connection) - return await ptable.get_seqr_projects() + seqr_project_ids = await ptable.get_seqr_project_ids() + my_seqr_projects = [ + p for p in connection.all_projects() if p.id in seqr_project_ids + ] + + # Fail if user doesn't have access to all seqr projects. 
This endpoint is used + # for joint-calling where we would want to include all seqr projects and know if + # any are missing, so it is important to raise an error here rather than just + # excluding projects due to permission issues + if len(my_seqr_projects) != len(seqr_project_ids): + raise Forbidden('The current user does not have access to all seqr projects') + + return my_seqr_projects @router.post('/{project}/update', operation_id='updateProject') async def update_project( - project: str, project_update_model: dict, - connection: Connection = get_projectless_db_connection, + connection: Connection = get_project_db_connection( + {ProjectMemberRole.project_admin} + ), ): """Update a project by project name""" ptable = ProjectPermissionsTable(connection) + + project = connection.project + assert project return await ptable.update_project( - project_name=project, update=project_update_model, author=connection.author + project_name=project.name, update=project_update_model, author=connection.author ) @router.delete('/{project}', operation_id='deleteProjectData') async def delete_project_data( - project: str, delete_project: bool = False, - connection: Connection = get_projectless_db_connection, + connection: Connection = get_project_db_connection( + {ProjectMemberRole.project_admin} + ), ): """ Delete all data in a project by project name. 
@@ -85,11 +116,12 @@ async def delete_project_data( Requires READ access + project-creator permissions """ ptable = ProjectPermissionsTable(connection) - p_obj = await ptable.get_and_check_access_to_project_for_name( - user=connection.author, project_name=project, readonly=False - ) + + assert connection.project + success = await ptable.delete_project_data( - project_id=p_obj.id, delete_project=delete_project, author=connection.author + project_id=connection.project.id, + delete_project=delete_project, ) return {'success': success} @@ -97,20 +129,24 @@ async def delete_project_data( @router.patch('/{project}/members', operation_id='updateProjectMembers') async def update_project_members( - project: str, - members: list[str], - readonly: bool, - connection: Connection = get_projectless_db_connection, + members: list[ProjectMemberUpdate], + connection: Connection = get_project_db_connection( + {ProjectMemberRole.project_member_admin} + ), ): """ Update project members for specific read / write group. 
Not that this is protected by access to a specific access group """ ptable = ProjectPermissionsTable(connection) - await ptable.set_group_members( - group_name=ptable.get_project_group_name(project, readonly=readonly), - members=members, - author=connection.author, - ) + + for member in members: + for role in member.roles: + if role not in project_member_role_names: + raise HTTPException( + 400, f'Role {role} is not valid for member {member.member}' + ) + assert connection.project + await ptable.set_project_members(project=connection.project, members=members) return {'success': True} diff --git a/api/routes/sample.py b/api/routes/sample.py index 982e89798..b265239f2 100644 --- a/api/routes/sample.py +++ b/api/routes/sample.py @@ -2,13 +2,12 @@ from api.utils.db import ( Connection, - get_project_readonly_connection, - get_project_write_connection, + get_project_db_connection, get_projectless_db_connection, ) from db.python.layers.sample import SampleLayer -from db.python.tables.project import ProjectPermissionsTable from models.base import SMBase +from models.models.project import FullWriteAccessRoles, ReadAccessRoles from models.models.sample import SampleUpsert from models.utils.sample_id_format import ( # Sample, sample_id_format, @@ -23,7 +22,8 @@ @router.put('/{project}/', response_model=str, operation_id='createSample') async def create_sample( - sample: SampleUpsert, connection: Connection = get_project_write_connection + sample: SampleUpsert, + connection: Connection = get_project_db_connection(FullWriteAccessRoles), ) -> str: """Creates a new sample, and returns the internal sample ID""" st = SampleLayer(connection) @@ -37,7 +37,7 @@ async def create_sample( ) async def upsert_samples( samples: list[SampleUpsert], - connection: Connection = get_project_write_connection, + connection: Connection = get_project_db_connection(FullWriteAccessRoles), ): """ Upserts a list of samples with sequencing-groups, @@ -62,7 +62,7 @@ async def upsert_samples( async def 
get_sample_id_map_by_external( external_ids: list[str], allow_missing: bool = False, - connection: Connection = get_project_readonly_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), ): """Get map of sample IDs, { [externalId]: internal_sample_id }""" st = SampleLayer(connection) @@ -91,12 +91,14 @@ async def get_sample_id_map_by_internal( '/{project}/id-map/internal/all', operation_id='getAllSampleIdMapByInternal' ) async def get_all_sample_id_map_by_internal( - connection: Connection = get_project_readonly_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), ): """Get map of ALL sample IDs, { [internal_id]: external_sample_id }""" st = SampleLayer(connection) - assert connection.project - result = await st.get_all_sample_id_map_by_internal_ids(project=connection.project) + assert connection.project_id + result = await st.get_all_sample_id_map_by_internal_ids( + project=connection.project_id + ) return {sample_id_format(k): v for k, v in result.items()} @@ -107,12 +109,15 @@ async def get_all_sample_id_map_by_internal( operation_id='getSampleByExternalId', ) async def get_sample_by_external_id( - external_id: str, connection: Connection = get_project_readonly_connection + external_id: str, + connection: Connection = get_project_db_connection(ReadAccessRoles), ): """Get sample by external ID""" st = SampleLayer(connection) - assert connection.project - result = await st.get_single_by_external_id(external_id, project=connection.project) + assert connection.project_id + result = await st.get_single_by_external_id( + external_id, project=connection.project_id + ) return result.to_external() @@ -136,12 +141,12 @@ async def get_samples( """ st = SampleLayer(connection) - pt = ProjectPermissionsTable(connection) pids: list[int] | None = None if criteria.project_ids: - pids = await pt.get_project_ids_from_names_and_user( - connection.author, criteria.project_ids, readonly=True + projects = 
connection.get_and_check_access_to_projects_for_names( + criteria.project_ids, allowed_roles=ReadAccessRoles ) + pids = [p.id for p in projects] sample_ids_raw = ( sample_id_transform_to_raw_list(criteria.sample_ids) @@ -155,7 +160,6 @@ async def get_samples( participant_ids=criteria.participant_ids, project_ids=pids, active=criteria.active, - check_project_ids=True, ) return [r.to_external() for r in result] diff --git a/api/routes/sequencing_groups.py b/api/routes/sequencing_groups.py index 9a35ee1e4..09c3c44a1 100644 --- a/api/routes/sequencing_groups.py +++ b/api/routes/sequencing_groups.py @@ -5,11 +5,11 @@ from api.utils.db import ( Connection, - get_project_readonly_connection, - get_project_write_connection, + get_project_db_connection, get_projectless_db_connection, ) from db.python.layers.sequencing_group import SequencingGroupLayer +from models.models.project import FullWriteAccessRoles, ReadAccessRoles from models.models.sequencing_group import SequencingGroupUpsertInternal from models.utils.sample_id_format import sample_id_format from models.utils.sequencing_group_id_format import ( # Sample, @@ -42,7 +42,7 @@ async def get_sequencing_group( @router.get('/project/{project}', operation_id='getAllSequencingGroupIdsBySampleByType') async def get_all_sequencing_group_ids_by_sample_by_type( - connection: Connection = get_project_readonly_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), ) -> dict[str, dict[str, list[str]]]: """Creates a new sample, and returns the internal sample ID""" st = SequencingGroupLayer(connection) @@ -60,7 +60,7 @@ async def get_all_sequencing_group_ids_by_sample_by_type( async def update_sequencing_group( sequencing_group_id: str, sequencing_group: SequencingGroupMetaUpdateModel, - connection: Connection = get_project_write_connection, + connection: Connection = get_project_db_connection(FullWriteAccessRoles), ) -> bool: """Update the meta fields of a sequencing group""" st = 
SequencingGroupLayer(connection) diff --git a/api/routes/web.py b/api/routes/web.py index 58974e4a3..c601b49d9 100644 --- a/api/routes/web.py +++ b/api/routes/web.py @@ -14,8 +14,7 @@ from api.utils.db import ( Connection, - get_project_readonly_connection, - get_project_write_connection, + get_project_db_connection, get_projectless_db_connection, ) from api.utils.export import ExportType @@ -23,10 +22,10 @@ from db.python.layers.search import SearchLayer from db.python.layers.seqr import SeqrLayer from db.python.layers.web import WebLayer -from db.python.tables.project import ProjectPermissionsTable from models.base import SMBase from models.enums.web import MetaSearchEntityPrefix, SeqrDatasetType from models.models.participant import NestedParticipant +from models.models.project import FullWriteAccessRoles, ReadAccessRoles from models.models.search import SearchResponse from models.models.web import ( ProjectParticipantGridField, @@ -50,7 +49,7 @@ class SearchResponseModel(SMBase): operation_id='getProjectSummary', ) async def get_project_summary( - connection: Connection = get_project_readonly_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), ) -> ProjectSummary: """Creates a new sample, and returns the internal sample ID""" st = WebLayer(connection) @@ -66,7 +65,7 @@ async def get_project_summary( operation_id='getProjectParticipantsFilterSchema', ) async def get_project_project_participants_filter_schema( - _=get_project_readonly_connection, + _=get_project_db_connection(ReadAccessRoles), ): """Get project summary (from query) with some limit""" return ProjectParticipantGridFilter.model_json_schema() @@ -81,15 +80,15 @@ async def get_project_participants_grid_with_limit( limit: int, query: ProjectParticipantGridFilter, skip: int = 0, - connection=get_project_readonly_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), ): """Get project summary (from query) with some limit""" - if not 
connection.project: + if not connection.project_id: raise ValueError('No project was detected through the authentication') wlayer = WebLayer(connection) - pfilter = query.to_internal(project=connection.project) + pfilter = query.to_internal(project=connection.project_id) participants, pcount = await asyncio.gather( wlayer.query_participants(pfilter, limit=limit, skip=skip), @@ -117,15 +116,15 @@ async def export_project_participants( export_type: ExportType, query: ProjectParticipantGridFilter, fields: ExportProjectParticipantFields | None = None, - connection=get_project_readonly_connection, + connection: Connection = get_project_db_connection(ReadAccessRoles), ): """Get project summary (from query) with some limit""" - if not connection.project: + if not connection.project_id: raise ValueError('No project was detected through the authentication') wlayer = WebLayer(connection) - pfilter = query.to_internal(project=connection.project) + pfilter = query.to_internal(project=connection.project_id) participants_internal = await wlayer.query_participants(pfilter, limit=None) participants = [p.to_external() for p in participants_internal] @@ -251,15 +250,15 @@ def get_visible_fields(key: MetaSearchEntityPrefix): @router.get( '/search', response_model=SearchResponseModel, operation_id='searchByKeyword' ) -async def search_by_keyword(keyword: str, connection=get_projectless_db_connection): +async def search_by_keyword( + keyword: str, connection: Connection = get_projectless_db_connection +): """ This searches the keyword, in families, participants + samples in the projects that you are a part of (automatically). 
""" - pt = ProjectPermissionsTable(connection) - projects = await pt.get_projects_accessible_by_user( - connection.author, readonly=True - ) + # raise ValueError("Test") + projects = connection.all_projects() pmap = {p.id: p for p in projects} responses = await SearchLayer(connection).search( keyword, project_ids=[p for p in pmap.keys() if p] @@ -289,7 +288,7 @@ async def sync_seqr_project( sync_saved_variants: bool = True, sync_cram_map: bool = True, post_slack_notification: bool = True, - connection=get_project_write_connection, + connection: Connection = get_project_db_connection(FullWriteAccessRoles), ): """ Sync a metamist project with its seqr project (for a specific sequence type) diff --git a/api/server.py b/api/server.py index 82e0684ff..cfee85348 100644 --- a/api/server.py +++ b/api/server.py @@ -17,15 +17,15 @@ PROFILE_REQUESTS, PROFILE_REQUESTS_OUTPUT, SKIP_DATABASE_CONNECTION, + SM_ENVIRONMENT, ) -from api.utils import get_openapi_schema_func from api.utils.exceptions import determine_code_from_error +from api.utils.openapi import get_openapi_schema_func from db.python.connect import SMConnections -from db.python.tables.project import is_all_access from db.python.utils import get_logger # This tag is automatically updated by bump2version -_VERSION = '7.1.1' +_VERSION = '7.2.0' logger = get_logger() @@ -92,7 +92,7 @@ async def profile_request(request: Request, call_next): return resp -if is_all_access(): +if SM_ENVIRONMENT == 'local': app.add_middleware( CORSMiddleware, allow_origins=['*'], diff --git a/api/settings.py b/api/settings.py index ef3ea9eeb..072b262d8 100644 --- a/api/settings.py +++ b/api/settings.py @@ -9,9 +9,8 @@ LOG_DATABASE_QUERIES = ( os.getenv('SM_LOG_DATABASE_QUERIES', 'false').lower() in TRUTH_SET ) -_ALLOW_ALL_ACCESS: bool = os.getenv('SM_ALLOWALLACCESS', 'n').lower() in TRUTH_SET _DEFAULT_USER = os.getenv('SM_LOCALONLY_DEFAULTUSER') -SM_ENVIRONMENT = os.getenv('SM_ENVIRONMENT', 'local').lower() +SM_ENVIRONMENT = 
os.getenv('SM_ENVIRONMENT', 'production').lower() SKIP_DATABASE_CONNECTION = bool(os.getenv('SM_SKIP_DATABASE_CONNECTION')) PROFILE_REQUESTS = os.getenv('SM_PROFILE_REQUESTS', 'false').lower() in TRUTH_SET PROFILE_REQUESTS_OUTPUT = os.getenv('SM_PROFILE_REQUESTS_OUTPUT', 'text').lower() @@ -61,22 +60,11 @@ def get_default_user() -> str | None: """Determine if a default user is available""" - if is_all_access() and _DEFAULT_USER: + if SM_ENVIRONMENT in ('local', 'test') and _DEFAULT_USER: return _DEFAULT_USER return None -def is_all_access() -> bool: - """Does SM have full access""" - return _ALLOW_ALL_ACCESS - - -def set_all_access(access: bool): - """Set full_access for future use""" - global _ALLOW_ALL_ACCESS - _ALLOW_ALL_ACCESS = access - - @lru_cache def get_slack_token(allow_empty=False): """Get slack token""" diff --git a/api/utils/__init__.py b/api/utils/__init__.py index fe5ae4cdd..089785938 100644 --- a/api/utils/__init__.py +++ b/api/utils/__init__.py @@ -1,15 +1,8 @@ """Importing GCP libraries""" + from collections import defaultdict from typing import Callable, Iterable, TypeVar -from .db import ( - authenticate, - get_project_readonly_connection, - get_project_write_connection, - get_projectless_db_connection, -) -from .openapi import get_openapi_schema_func - T = TypeVar('T') X = TypeVar('X') diff --git a/api/utils/db.py b/api/utils/db.py index 608a49888..aa0cb38a0 100644 --- a/api/utils/db.py +++ b/api/utils/db.py @@ -1,6 +1,7 @@ import json import logging from os import getenv +from typing import Any from fastapi import Depends, HTTPException, Request from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -11,7 +12,7 @@ from api.utils.gcp import email_from_id_token from db.python.connect import Connection, SMConnections from db.python.gcp_connect import BqConnection, PubSubConnection -from db.python.tables.project import ProjectPermissionsTable +from models.models.project import ProjectMemberRole EXPECTED_AUDIENCE = 
getenv('SM_OAUTHAUDIENCE') @@ -30,7 +31,7 @@ def get_ar_guid(request: Request) -> str | None: return request.headers.get('sm-ar-guid') -def get_extra_audit_log_values(request: Request) -> dict | None: +def get_extra_audit_log_values(request: Request) -> dict[str, Any] | None: """Get a JSON encoded dictionary from the 'sm-extra-values' header if it exists""" headers = request.headers.get('sm-extra-values') if not headers: @@ -78,86 +79,43 @@ def authenticate( raise HTTPException(status_code=401, detail='Not authenticated :(') -async def dependable_get_write_project_connection( - project: str, - request: Request, - author: str = Depends(authenticate), - ar_guid: str = Depends(get_ar_guid), - extra_values: dict | None = Depends(get_extra_audit_log_values), - on_behalf_of: str | None = Depends(get_on_behalf_of), -) -> Connection: - """FastAPI handler for getting connection WITH project""" - meta = {'path': request.url.path} - if request.client: - meta['ip'] = request.client.host - if extra_values: - meta.update(extra_values) - - return await ProjectPermissionsTable.get_project_connection( - project_name=project, - author=author, - readonly=False, - ar_guid=ar_guid, - on_behalf_of=on_behalf_of, - meta=meta, - ) - - -async def HACK_dependable_contributor_project_connection( - project: str, - author: str = Depends(authenticate), - ar_guid: str = Depends(get_ar_guid), - extra_values: dict | None = Depends(get_extra_audit_log_values), -) -> Connection: - """FastAPI handler for getting connection WITH project""" - meta = {} - if extra_values: - meta.update(extra_values) - - meta['role'] = 'contributor' - - # hack by making it appear readonly - connection = await ProjectPermissionsTable.get_project_connection( - project_name=project, - author=author, - readonly=True, - on_behalf_of=None, - ar_guid=ar_guid, - meta=meta, - ) - - # then hack it so - connection.readonly = False - - return connection - - -async def dependable_get_readonly_project_connection( - project: str, - 
author: str = Depends(authenticate), - ar_guid: str = Depends(get_ar_guid), - extra_values: dict | None = Depends(get_extra_audit_log_values), -) -> Connection: - """FastAPI handler for getting connection WITH project""" - meta = {} - if extra_values: - meta.update(extra_values) +def dependable_get_project_db_connection(allowed_roles: set[ProjectMemberRole]): + """Return a partially applied dependable db connection with allowed roles applied""" + + async def dependable_project_db_connection( + project: str, + request: Request, + author: str = Depends(authenticate), + ar_guid: str = Depends(get_ar_guid), + extra_values: dict[str, Any] | None = Depends(get_extra_audit_log_values), + on_behalf_of: str | None = Depends(get_on_behalf_of), + ) -> Connection: + """FastAPI handler for getting connection WITH project""" + meta = {'path': request.url.path} + if request.client: + meta['ip'] = request.client.host + + if extra_values: + meta.update(extra_values) + + return await SMConnections.get_connection_with_project( + project_name=project, + author=author, + allowed_roles=allowed_roles, + on_behalf_of=on_behalf_of, + ar_guid=ar_guid, + meta=meta, + ) - return await ProjectPermissionsTable.get_project_connection( - project_name=project, - author=author, - readonly=True, - on_behalf_of=None, - ar_guid=ar_guid, - meta=meta, - ) + return dependable_project_db_connection async def dependable_get_connection( request: Request, author: str = Depends(authenticate), ar_guid: str = Depends(get_ar_guid), - extra_values: dict | None = Depends(get_extra_audit_log_values), + extra_values: dict[str, Any] | None = Depends(get_extra_audit_log_values), + on_behalf_of: str | None = Depends(get_on_behalf_of), ): """FastAPI handler for getting connection withOUT project""" meta = {'path': request.url.path} @@ -168,7 +126,7 @@ async def dependable_get_connection( meta.update(extra_values) return await SMConnections.get_connection_no_project( - author, ar_guid=ar_guid, meta=meta + author, 
ar_guid=ar_guid, meta=meta, on_behalf_of=on_behalf_of ) @@ -184,7 +142,7 @@ async def dependable_get_pubsub_connection( return await PubSubConnection.get_connection_no_project(author, topic) -def validate_iap_jwt_and_get_email(iap_jwt, audience): +def validate_iap_jwt_and_get_email(iap_jwt: str, audience: str): """ Validate an IAP JWT and return email Source: https://cloud.google.com/iap/docs/signed-headers-howto @@ -207,11 +165,13 @@ def validate_iap_jwt_and_get_email(iap_jwt, audience): get_author = Depends(authenticate) -get_project_readonly_connection = Depends(dependable_get_readonly_project_connection) -HACK_get_project_contributor_connection = Depends( - HACK_dependable_contributor_project_connection -) -get_project_write_connection = Depends(dependable_get_write_project_connection) + + +def get_project_db_connection(allowed_roles: set[ProjectMemberRole]): + """Get a project db connection with allowed roles applied""" + return Depends(dependable_get_project_db_connection(allowed_roles)) + + get_projectless_db_connection = Depends(dependable_get_connection) get_projectless_bq_connection = Depends(dependable_get_bq_connection) get_projectless_pubsub_connection = Depends(dependable_get_pubsub_connection) diff --git a/db/project.xml b/db/project.xml index 85917b096..adea9622a 100644 --- a/db/project.xml +++ b/db/project.xml @@ -1449,4 +1449,69 @@ + + + SET @@system_versioning_alter_history = 1; + + + + + + + + + + + + + + + + + + + + + ALTER TABLE project_member ADD SYSTEM VERSIONING; + + + + + + -- copy across read and write roles from existing project table + INSERT INTO project_member ( + SELECT + p.id AS project_id, + gm.member as member, + 'reader' AS role, + coalesce(gm.audit_log_id, p.audit_log_id, g.audit_log_id) as audit_log_id + FROM project p + JOIN `group` g on g.id = p.read_group_id + JOIN group_member gm on gm.group_id = g.id + WHERE p.read_group_id IS NOT NULL + + UNION ALL + + SELECT + p.id AS project_id, + gm.member as member, + 'writer' AS 
role, + coalesce(gm.audit_log_id, p.audit_log_id, g.audit_log_id) as audit_log_id + FROM project p + JOIN `group` g on g.id = p.write_group_id + JOIN group_member gm on gm.group_id = g.id + WHERE p.read_group_id IS NOT NULL + ) + + + + + + + + diff --git a/db/python/connect.py b/db/python/connect.py index 5456c511f..38e3f81d6 100644 --- a/db/python/connect.py +++ b/db/python/connect.py @@ -8,11 +8,19 @@ import json import logging import os +from typing import Iterable import databases from api.settings import LOG_DATABASE_QUERIES -from db.python.utils import InternalError +from db.python.tables.project import ProjectPermissionsTable +from db.python.utils import ( + InternalError, + NoProjectAccess, + NotFoundError, + ProjectDoesNotExist, +) +from models.models.project import Project, ProjectId, ProjectMemberRole logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) @@ -50,31 +58,175 @@ class Connection: def __init__( self, connection: databases.Database, - project: int | None, + project: Project | None, + project_id_map: dict[ProjectId, Project], + project_name_map: dict[str, Project], author: str, on_behalf_of: str | None, - readonly: bool, ar_guid: str | None, meta: dict[str, str] | None = None, ): - self.connection: databases.Database = connection - self.project: int | None = project + self.__connection: databases.Database = connection + self.__project: Project | None = project + self.__project_id_map = project_id_map + self.__project_name_map = project_name_map self.author: str = author self.on_behalf_of: str | None = on_behalf_of - self.readonly: bool = readonly self.ar_guid: str | None = ar_guid self.meta = meta self._audit_log_id: int | None = None self._audit_log_lock = asyncio.Lock() - async def audit_log_id(self): - """Get audit_log ID for write operations, cached per connection""" - if self.readonly: + @property + def connection(self): + """Public getter for private project class variable""" + return self.__connection + + 
@property + def project(self): + """Public getter for private project class variable""" + return self.__project + + @property + def project_id_map(self): + """Public getter for private project_id_map class variable""" + return self.__project_id_map + + @property + def project_name_map(self): + """Public getter for private project_name_map class variable""" + return self.__project_name_map + + @property + def project_id(self): + """Safely get the project id from the project model attached to the connection""" + return self.project.id if self.project is not None else None + + def all_projects(self): + """Return all projects that the current user has access to""" + return list(self.project_id_map.values()) + + def projects_with_role(self, allowed_roles: set[ProjectMemberRole]): + """Return all projects that the current user has access to""" + return [p for p in self.project_id_map.values() if p.roles & allowed_roles] + + def get_and_check_access_to_projects( + self, projects: Iterable[Project], allowed_roles: set[ProjectMemberRole] + ): + """ + Check if the current user has _any_ of the specified roles in _all_ of the + specified projects. Raise an error if they do not. + """ + # projects that the user has some access to, but not the required access + disallowed_projects = [p for p in projects if not p.roles & allowed_roles] + + if disallowed_projects: + raise NoProjectAccess( + [p.name for p in disallowed_projects], + allowed_roles=[r.name for r in allowed_roles], + author=self.author, + ) + + return projects + + def get_and_check_access_to_projects_for_ids( + self, project_ids: Iterable[ProjectId], allowed_roles: set[ProjectMemberRole] + ): + """ + Check if the current user has _any_ of the specified roles in _all_ of the + projects based on the specified project ids. Raise an error if they do not. + Also raise an error if any of the specified project ids doesn't exist or the + current user has no access to it. 
Return the matching projects + """ + projects = [ + self.project_id_map[id] for id in project_ids if id in self.project_id_map + ] + + # Check if any of the provided ids aren't valid project ids, or the user has + # no access to them at all. A NotFoundError is raised here rather than a + # Forbidden so as to not leak the existence of the project to those with no access. + missing_project_ids = set(project_ids) - set(p.id for p in projects) + if missing_project_ids: + missing_project_ids_str = ', '.join([str(p) for p in missing_project_ids]) + raise NotFoundError( + f'Could not find projects with ids: {missing_project_ids_str}' + ) + + return self.get_and_check_access_to_projects(projects, allowed_roles) + + def check_access_to_projects_for_ids( + self, project_ids: Iterable[ProjectId], allowed_roles: set[ProjectMemberRole] + ): + """ + Check if the current user has _any_ of the specified roles in _all_ of the + projects based on the specified project ids. Raise an error if they do not. + Also raise an error if any of the specified project ids doesn't exist or the + current user has no access to it. Returns None + """ + self.get_and_check_access_to_projects_for_ids(project_ids, allowed_roles) + + def get_and_check_access_to_projects_for_names( + self, project_names: Iterable[str], allowed_roles: set[ProjectMemberRole] + ): + """ + Check if the current user has _any_ of the specified roles in _all_ of the + projects based on the specified project names. Raise an error if they do not. + Also raise an error if any of the specified project names doesn't exist or + the current user has no access to it. Return the matching projects + """ + projects = [ + self.project_name_map[name] + for name in project_names + if name in self.project_name_map + ] + + # Check if any of the provided names aren't valid project names, or the user has + # no access to them at all. 
A NotFoundError is raised here rather than a + # Forbidden so as to not leak the existence of the project to those with no access. + missing_project_names = set(project_names) - set(p.name for p in projects) + + if missing_project_names: + missing_project_names_str = ', '.join( + [f'"{str(p)}"' for p in missing_project_names] + ) + raise NotFoundError( + f'Could not find projects with names: {missing_project_names_str}' + ) + + return self.get_and_check_access_to_projects(projects, allowed_roles) + + def check_access_to_projects_for_names( + self, project_names: Iterable[str], allowed_roles: set[ProjectMemberRole] + ): + """ + Check if the current user has _any_ of the specified roles in _all_ of the + projects based on the specified project names. Raise an error if they do not. + Also raise an error if any of the specified project names doesn't exist or + the current user has no access to it. Returns None + """ + self.get_and_check_access_to_projects_for_names(project_names, allowed_roles) + + def check_access(self, allowed_roles: set[ProjectMemberRole]): + """ + Check if the current user has the specified role within the project that is + attached to the connection. If there is no project attached to the connection + this will raise an error. 
+ """ + if self.project is None: raise InternalError( - 'Trying to get a audit_log ID, but not a write connection' + 'Connection was expected to have a project attached, but did not' + ) + if not allowed_roles & self.project.roles: + raise NoProjectAccess( + project_names=[self.project.name], + author=self.author, + allowed_roles=[r.name for r in allowed_roles], ) + async def audit_log_id(self): + """Get audit_log ID for write operations, cached per connection""" + async with self._audit_log_lock: if not self._audit_log_id: @@ -89,26 +241,37 @@ async def audit_log_id(self): on_behalf_of=self.on_behalf_of, ar_guid=self.ar_guid, comment=None, - project=self.project, + project=self.project_id, meta=self.meta, ) return self._audit_log_id - def assert_requires_project(self): - """Assert the project is set, or return an exception""" - if self.project is None: - raise InternalError( - 'An internal error has occurred when passing the project context, ' - 'please send this stacktrace to your system administrator' - ) + async def refresh_projects(self): + """ + Re-fetch the projects for the current user and update the connection. + This only really needs to be run after project member updates or project + creation, and really only for tests. 
The API fetches projects on each request + so subsequent requests after updates will already have up-to-date data + """ + conn = self.connection + pt = ProjectPermissionsTable(connection=None, database_connection=conn) + + project_id_map, project_name_map = await pt.get_projects_accessible_by_user( + user=self.author + ) + self.__project_id_map = project_id_map + self.__project_name_map = project_name_map + + if self.project_id: + self.__project = self.project_id_map.get(self.project_id) class DatabaseConfiguration(abc.ABC): """Base class for DatabaseConfiguration""" @abc.abstractmethod - def get_connection_string(self): + def get_connection_string(self) -> str: """Get connection string""" raise NotImplementedError @@ -116,7 +279,7 @@ def get_connection_string(self): class ConnectionStringDatabaseConfiguration(DatabaseConfiguration): """Database Configuration that takes a literal DatabaseConfiguration""" - def __init__(self, connection_string): + def __init__(self, connection_string: str): self.connection_string = connection_string def get_connection_string(self): @@ -128,11 +291,11 @@ class CredentialedDatabaseConfiguration(DatabaseConfiguration): def __init__( self, - dbname, - host=None, - port=None, - username=None, - password=None, + dbname: str, + host: str | None = None, + port: str | None = None, + username: str | None = None, + password: str | None = None, ): self.dbname = dbname self.host = host @@ -156,6 +319,8 @@ def get_connection_string(self): """Prepares the connection string for mysql / mariadb""" _host = self.host or 'localhost' + + assert self.username u_p = self.username if self.password: @@ -176,7 +341,7 @@ def get_connection_string(self): class SMConnections: """Contains useful functions for connecting to the database""" - _credentials: DatabaseConfiguration | None = None + _credentials: CredentialedDatabaseConfiguration | None = None @staticmethod def _get_config(): @@ -232,9 +397,48 @@ async def get_made_connection(): return conn + 
@staticmethod + async def get_connection_with_project( + *, + author: str, + project_name: str, + allowed_roles: set[ProjectMemberRole], + ar_guid: str, + on_behalf_of: str | None = None, + meta: dict[str, str] | None = None, + ): + """Get a db connection from a project and user""" + # maybe it makes sense to perform permission checks here too + logger.debug(f'Authenticate connection to {project_name} with {author!r}') + + conn = await SMConnections.get_made_connection() + pt = ProjectPermissionsTable(connection=None, database_connection=conn) + + project_id_map, project_name_map = await pt.get_projects_accessible_by_user( + user=author + ) + + if project_name not in project_name_map: + raise ProjectDoesNotExist(project_name) + + connection = Connection( + connection=conn, + author=author, + project=project_name_map[project_name], + project_id_map=project_id_map, + project_name_map=project_name_map, + on_behalf_of=on_behalf_of, + ar_guid=ar_guid, + meta=meta, + ) + + connection.check_access(allowed_roles) + + return connection + @staticmethod async def get_connection_no_project( - author: str, ar_guid: str, meta: dict[str, str] + author: str, ar_guid: str, meta: dict[str, str], on_behalf_of: str | None ): """Get a db connection from a project and user""" # maybe it makes sense to perform permission checks here too @@ -242,15 +446,18 @@ async def get_connection_no_project( conn = await SMConnections.get_made_connection() - # we don't authenticate project-less connection, but rely on the - # the endpoint to validate the resources + pt = ProjectPermissionsTable(connection=None, database_connection=conn) + project_id_map, project_name_map = await pt.get_projects_accessible_by_user( + user=author + ) return Connection( connection=conn, author=author, project=None, - on_behalf_of=None, + on_behalf_of=on_behalf_of, ar_guid=ar_guid, - readonly=False, meta=meta, + project_id_map=project_id_map, + project_name_map=project_name_map, ) diff --git a/db/python/layers/analysis.py 
b/db/python/layers/analysis.py index 60f216d85..f136d49b1 100644 --- a/db/python/layers/analysis.py +++ b/db/python/layers/analysis.py @@ -21,7 +21,7 @@ ProportionalDateTemporalMethod, SequencingGroupInternal, ) -from models.models.project import ProjectId +from models.models.project import FullWriteAccessRoles, ProjectId, ReadAccessRoles from models.models.sequencing_group import SequencingGroupInternalId ES_ANALYSIS_OBJ_INTRO_DATE = datetime.date(2022, 6, 21) @@ -54,13 +54,13 @@ def __init__(self, connection: Connection): # GETS - async def get_analysis_by_id(self, analysis_id: int, check_project_id=True): + async def get_analysis_by_id(self, analysis_id: int): """Get analysis by ID""" project, analysis = await self.at.get_analysis_by_id(analysis_id) - if check_project_id: - await self.ptable.check_access_to_project_id( - self.author, project, readonly=True - ) + + self.connection.check_access_to_projects_for_ids( + [project], allowed_roles=ReadAccessRoles + ) return analysis @@ -115,17 +115,17 @@ async def get_sample_cram_path_map_for_seqr( participant_ids=participant_ids, ) - async def query(self, filter_: AnalysisFilter, check_project_ids=True): + async def query(self, filter_: AnalysisFilter) -> list[AnalysisInternal]: """Query analyses""" analyses = await self.at.query(filter_) if not analyses: return [] - if check_project_ids and not filter_.project: - await self.ptable.check_access_to_project_ids( - self.author, set(a.project for a in analyses), readonly=True - ) + self.connection.check_access_to_projects_for_ids( + set(a.project for a in analyses if a.project is not None), + allowed_roles=ReadAccessRoles, + ) return analyses @@ -156,8 +156,8 @@ async def get_cram_size_proportionate_map( if start_date < datetime.date(2020, 1, 1): raise ValueError(f'start_date ({start_date}) must be after 2020-01-01') - project_objs = await self.ptable.get_and_check_access_to_projects_for_ids( - project_ids=projects, user=self.author, readonly=True + project_objs = 
self.connection.get_and_check_access_to_projects_for_ids( + project_ids=projects, allowed_roles=ReadAccessRoles ) project_name_map = {p.id: p.name for p in project_objs} @@ -554,14 +554,13 @@ async def create_analysis( ) async def add_sequencing_groups_to_analysis( - self, analysis_id: int, sequencing_group_ids: list[int], check_project_id=True + self, analysis_id: int, sequencing_group_ids: list[int] ): """Add samples to an analysis (through the linked table)""" - if check_project_id: - project_ids = await self.at.get_project_ids_for_analysis_ids([analysis_id]) - await self.ptable.check_access_to_project_ids( - self.author, project_ids, readonly=False - ) + project_ids = await self.at.get_project_ids_for_analysis_ids([analysis_id]) + self.connection.check_access_to_projects_for_ids( + project_ids, allowed_roles=FullWriteAccessRoles + ) return await self.at.add_sequencing_groups_to_analysis( analysis_id=analysis_id, sequencing_group_ids=sequencing_group_ids @@ -573,16 +572,14 @@ async def update_analysis( status: AnalysisStatus, meta: dict[str, Any] = None, output: str | None = None, - check_project_id=True, ): """ Update the status of an analysis, set timestamp_completed if relevant """ - if check_project_id: - project_ids = await self.at.get_project_ids_for_analysis_ids([analysis_id]) - await self.ptable.check_access_to_project_ids( - self.author, project_ids, readonly=False - ) + project_ids = await self.at.get_project_ids_for_analysis_ids([analysis_id]) + self.connection.check_access_to_projects_for_ids( + project_ids, allowed_roles=FullWriteAccessRoles + ) await self.at.update_analysis( analysis_id=analysis_id, diff --git a/db/python/layers/analysis_runner.py b/db/python/layers/analysis_runner.py index 7cb504847..3422ee5c6 100644 --- a/db/python/layers/analysis_runner.py +++ b/db/python/layers/analysis_runner.py @@ -2,6 +2,7 @@ from db.python.layers.base import BaseLayer from db.python.tables.analysis_runner import AnalysisRunnerFilter, AnalysisRunnerTable 
from models.models.analysis_runner import AnalysisRunnerInternal +from models.models.project import ReadAccessRoles class AnalysisRunnerLayer(BaseLayer): @@ -15,19 +16,18 @@ def __init__(self, connection: Connection): # GETS async def query( - self, filter_: AnalysisRunnerFilter, check_project_ids: bool = True + self, filter_: AnalysisRunnerFilter ) -> list[AnalysisRunnerInternal]: """Get analysis runner logs""" logs = await self.at.query(filter_) if not logs: return [] - if check_project_ids: - project_ids = set(log.project for log in logs) + project_ids = set(log.project for log in logs) - await self.ptable.check_access_to_project_ids( - self.author, project_ids, readonly=True - ) + self.connection.check_access_to_projects_for_ids( + project_ids, allowed_roles=ReadAccessRoles + ) return logs diff --git a/db/python/layers/assay.py b/db/python/layers/assay.py index 89ab19a72..e04ff404d 100644 --- a/db/python/layers/assay.py +++ b/db/python/layers/assay.py @@ -5,6 +5,7 @@ from db.python.tables.sample import SampleTable from db.python.utils import NoOpAenter from models.models.assay import AssayInternal, AssayUpsertInternal +from models.models.project import FullWriteAccessRoles, ReadAccessRoles class AssayLayer(BaseLayer): @@ -16,7 +17,7 @@ def __init__(self, connection: Connection): self.sampt: SampleTable = SampleTable(connection) # GET - async def query(self, filter_: AssayFilter = None, check_project_id=True): + async def query(self, filter_: AssayFilter = None): """Query for samples""" projects, assays = await self.seqt.query(filter_) @@ -24,23 +25,19 @@ async def query(self, filter_: AssayFilter = None, check_project_id=True): if not assays: return [] - if check_project_id: - await self.ptable.check_access_to_project_ids( - user=self.author, project_ids=projects, readonly=True - ) + self.connection.check_access_to_projects_for_ids( + project_ids=projects, allowed_roles=ReadAccessRoles + ) return assays - async def get_assay_by_id( - self, assay_id: int, 
check_project_id=True - ) -> AssayInternal: + async def get_assay_by_id(self, assay_id: int) -> AssayInternal: """Get assay by ID""" project, assay = await self.seqt.get_assay_by_id(assay_id) - if check_project_id: - await self.ptable.check_access_to_project_id( - self.author, project, readonly=True - ) + self.connection.check_access_to_projects_for_ids( + [project], allowed_roles=ReadAccessRoles + ) return assay @@ -58,10 +55,7 @@ async def get_assay_by_external_id( return assay async def get_assays_for_sequencing_group_ids( - self, - sequencing_group_ids: list[int], - filter_: AssayFilter | None = None, - check_project_ids=True, + self, sequencing_group_ids: list[int], filter_: AssayFilter | None = None ) -> dict[int, list[AssayInternal]]: """Get assays for a list of sequencing group IDs""" if not sequencing_group_ids: @@ -75,17 +69,16 @@ async def get_assays_for_sequencing_group_ids( if not assays: return {} - if check_project_ids: - await self.ptable.check_access_to_project_ids( - self.author, projects, readonly=True - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) return assays # region UPSERTs async def upsert_assay( - self, assay: AssayUpsertInternal, check_project_id=True, open_transaction=True + self, assay: AssayUpsertInternal, open_transaction=True ) -> AssayUpsertInternal: """Upsert a single assay""" @@ -96,8 +89,9 @@ async def upsert_assay( project_ids = await self.sampt.get_project_ids_for_sample_ids( [assay.sample_id] ) - await self.ptable.check_access_to_project_ids( - self.author, project_ids, readonly=False + + self.connection.check_access_to_projects_for_ids( + project_ids, allowed_roles=FullWriteAccessRoles ) seq_id = await self.seqt.insert_assay( @@ -109,12 +103,12 @@ async def upsert_assay( ) assay.id = seq_id else: - if check_project_id: - # can check the project id of the assay we're updating - project_ids = await self.seqt.get_projects_by_assay_ids([assay.id]) - await 
self.ptable.check_access_to_project_ids( - self.author, project_ids, readonly=False - ) + # can check the project id of the assay we're updating + project_ids = await self.seqt.get_projects_by_assay_ids([assay.id]) + self.connection.check_access_to_projects_for_ids( + project_ids, allowed_roles=FullWriteAccessRoles + ) + # Otherwise update await self.seqt.update_assay( assay.id, @@ -129,27 +123,24 @@ async def upsert_assay( async def upsert_assays( self, assays: list[AssayUpsertInternal], - check_project_ids: bool = True, open_transaction=True, ) -> list[AssayUpsertInternal]: """Upsert multiple sequences to the given sample (sid)""" - if check_project_ids: - sample_ids = set(s.sample_id for s in assays) - st = SampleTable(self.connection) - project_ids = await st.get_project_ids_for_sample_ids(list(sample_ids)) - await self.ptable.check_access_to_project_ids( - self.author, project_ids, readonly=False - ) + sample_ids = set(s.sample_id for s in assays) + st = SampleTable(self.connection) + project_ids = await st.get_project_ids_for_sample_ids(list(sample_ids)) + + self.connection.check_access_to_projects_for_ids( + project_ids, allowed_roles=FullWriteAccessRoles + ) with_function = ( self.connection.connection.transaction if open_transaction else NoOpAenter ) async with with_function(): for a in assays: - await self.upsert_assay( - a, check_project_id=False, open_transaction=False - ) + await self.upsert_assay(a, open_transaction=False) return assays diff --git a/db/python/layers/audit_log.py b/db/python/layers/audit_log.py index 1cc351ecc..b3f86ad81 100644 --- a/db/python/layers/audit_log.py +++ b/db/python/layers/audit_log.py @@ -1,6 +1,7 @@ from db.python.layers.base import BaseLayer, Connection from db.python.tables.audit_log import AuditLogTable from models.models.audit_log import AuditLogId, AuditLogInternal +from models.models.project import ReadAccessRoles class AuditLogLayer(BaseLayer): @@ -11,19 +12,16 @@ def __init__(self, connection: Connection): 
self.alayer: AuditLogTable = AuditLogTable(connection) # GET - async def get_for_ids( - self, ids: list[AuditLogId], check_project_id=True - ) -> list[AuditLogInternal]: + async def get_for_ids(self, ids: list[AuditLogId]) -> list[AuditLogInternal]: """Query for samples""" if not ids: return [] logs = await self.alayer.get_audit_logs_for_ids(ids) - if check_project_id: - projects = {log.auth_project for log in logs} - await self.ptable.check_access_to_project_ids( - user=self.author, project_ids=projects, readonly=True - ) + projects = {log.auth_project for log in logs} + self.connection.check_access_to_projects_for_ids( + project_ids=projects, allowed_roles=ReadAccessRoles + ) return logs diff --git a/db/python/layers/base.py b/db/python/layers/base.py index 28c7575a7..edb893c52 100644 --- a/db/python/layers/base.py +++ b/db/python/layers/base.py @@ -1,4 +1,4 @@ -from db.python.tables.project import Connection, ProjectPermissionsTable +from db.python.connect import Connection class BaseLayer: @@ -6,7 +6,6 @@ class BaseLayer: def __init__(self, connection: Connection): self.connection = connection - self.ptable = ProjectPermissionsTable(connection) @property def author(self): diff --git a/db/python/layers/cohort.py b/db/python/layers/cohort.py index c69603d2e..1dfd9e327 100644 --- a/db/python/layers/cohort.py +++ b/db/python/layers/cohort.py @@ -3,7 +3,6 @@ from db.python.layers.base import BaseLayer from db.python.layers.sequencing_group import SequencingGroupLayer from db.python.tables.cohort import CohortFilter, CohortTable, CohortTemplateFilter -from db.python.tables.project import ProjectId, ProjectPermissionsTable from db.python.tables.sample import SampleFilter, SampleTable from db.python.tables.sequencing_group import ( SequencingGroupFilter, @@ -16,6 +15,7 @@ CohortTemplateInternal, NewCohortInternal, ) +from models.models.project import ProjectId, ReadAccessRoles logger = get_logger() @@ -69,29 +69,23 @@ def __init__(self, connection: Connection): 
self.sampt = SampleTable(connection) self.ct = CohortTable(connection) - self.pt = ProjectPermissionsTable(connection) self.sgt = SequencingGroupTable(connection) self.sglayer = SequencingGroupLayer(self.connection) - async def query( - self, filter_: CohortFilter, check_project_ids: bool = True - ) -> list[CohortInternal]: + async def query(self, filter_: CohortFilter) -> list[CohortInternal]: """Query Cohorts""" cohorts, project_ids = await self.ct.query(filter_) if not cohorts: return [] - if check_project_ids: - await self.pt.get_and_check_access_to_projects_for_ids( - user=self.connection.author, - project_ids=list(project_ids), - readonly=True, - ) + self.connection.check_access_to_projects_for_ids( + project_ids=list(project_ids), allowed_roles=ReadAccessRoles + ) return cohorts async def query_cohort_templates( - self, filter_: CohortTemplateFilter, check_project_ids: bool = True + self, filter_: CohortTemplateFilter ) -> list[CohortTemplateInternal]: """Query CohortTemplates""" project_ids, cohort_templates = await self.ct.query_cohort_templates(filter_) @@ -99,12 +93,10 @@ async def query_cohort_templates( if not cohort_templates: return [] - if check_project_ids: - await self.pt.get_and_check_access_to_projects_for_ids( - user=self.connection.author, - project_ids=list(project_ids), - readonly=True, - ) + self.connection.check_access_to_projects_for_ids( + project_ids=list(project_ids), allowed_roles=ReadAccessRoles + ) + return cohort_templates async def get_template_by_cohort_id(self, cohort_id: int) -> CohortTemplateInternal: diff --git a/db/python/layers/family.py b/db/python/layers/family.py index f0e9820e8..fcaa218a0 100644 --- a/db/python/layers/family.py +++ b/db/python/layers/family.py @@ -13,9 +13,10 @@ from db.python.tables.participant import ParticipantTable from db.python.tables.sample import SampleTable from db.python.utils import NotFoundError -from models.models import PRIMARY_EXTERNAL_ORG, ProjectId +from models.models import 
PRIMARY_EXTERNAL_ORG from models.models.family import FamilyInternal, PedRow, PedRowInternal from models.models.participant import ParticipantUpsertInternal +from models.models.project import ProjectId, ReadAccessRoles class FamilyLayer(BaseLayer): @@ -37,9 +38,7 @@ async def create_family( coded_phenotype=coded_phenotype, ) - async def get_family_by_internal_id( - self, family_id: int, check_project_id: bool = True - ) -> FamilyInternal: + async def get_family_by_internal_id(self, family_id: int) -> FamilyInternal: """Get family by internal ID""" projects, families = await self.ftable.query( FamilyFilter(id=GenericFilter(eq=family_id)) @@ -47,10 +46,10 @@ async def get_family_by_internal_id( if not families: raise NotFoundError(f'Family with ID {family_id} not found') family = families[0] - if check_project_id: - await self.ptable.check_access_to_project_ids( - self.author, projects, readonly=True - ) + + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) return family @@ -61,7 +60,7 @@ async def get_family_by_external_id( families = await self.ftable.query( FamilyFilter( external_id=GenericFilter(eq=external_id), - project=GenericFilter(eq=project or self.connection.project), + project=GenericFilter(eq=project or self.connection.project_id), ) ) if not families: @@ -72,7 +71,6 @@ async def get_family_by_external_id( async def query( self, filter_: FamilyFilter, - check_project_ids: bool = True, ) -> list[FamilyInternal]: """Get all families for a project""" @@ -80,10 +78,9 @@ async def query( projects, families = await self.ftable.query(filter_) - if check_project_ids: - await self.ptable.check_access_to_project_ids( - self.connection.author, projects, readonly=True - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) return families @@ -91,7 +88,6 @@ async def get_families_by_ids( self, family_ids: list[int], check_missing: bool = True, - check_project_ids: bool = True, 
) -> list[FamilyInternal]: """Get families by internal IDs""" projects, families = await self.ftable.query( @@ -100,10 +96,9 @@ async def get_families_by_ids( if not families: return [] - if check_project_ids: - await self.ptable.check_access_to_project_ids( - self.connection.author, projects, readonly=True - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) if check_missing and len(family_ids) != len(families): missing_ids = set(family_ids) - set(f.id for f in families) @@ -112,7 +107,7 @@ async def get_families_by_ids( return families async def get_families_by_participants( - self, participant_ids: list[int], check_project_ids: bool = True + self, participant_ids: list[int] ) -> dict[int, list[FamilyInternal]]: """ Get families keyed by participant_ids, this will duplicate families @@ -123,10 +118,9 @@ async def get_families_by_participants( if not participant_map: return {} - if check_project_ids: - await self.ptable.check_access_to_project_ids( - self.connection.author, projects, readonly=True - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) return participant_map @@ -136,14 +130,13 @@ async def update_family( external_id: str = None, description: str = None, coded_phenotype: str = None, - check_project_ids: bool = True, ) -> bool: """Update fields on some family""" - if check_project_ids: - project_ids = await self.ftable.get_projects_by_family_ids([id_]) - await self.ptable.check_access_to_project_ids( - self.author, project_ids, readonly=False - ) + project_ids = await self.ftable.get_projects_by_family_ids([id_]) + + self.connection.check_access_to_projects_for_ids( + project_ids, allowed_roles=ReadAccessRoles + ) return await self.ftable.update_family( id_=id_, @@ -214,9 +207,7 @@ async def get_pedigree( return mapped_rows - async def get_participant_family_map( - self, participant_ids: list[int], check_project_ids=False - ): + async def 
get_participant_family_map(self, participant_ids: list[int]): """Get participant family map""" fptable = FamilyParticipantTable(self.connection) @@ -224,8 +215,9 @@ async def get_participant_family_map( participant_ids=participant_ids ) - if check_project_ids: - raise NotImplementedError(f'Must check specified projects: {projects}') + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) return family_map @@ -277,7 +269,7 @@ async def import_pedigree( external_family_id_map = await self.ftable.get_id_map_by_external_ids( list(external_family_ids), - project=self.connection.project, + project=self.connection.project_id, allow_missing=True, ) missing_external_family_ids = [ @@ -285,7 +277,7 @@ async def import_pedigree( ] external_participant_ids_map = await participant_table.get_id_map_by_external_ids( list(external_participant_ids), - project=self.connection.project, + project=self.connection.project_id, # Allow missing participants if we're creating them allow_missing=create_missing_participants, ) @@ -419,7 +411,7 @@ def select_columns( return True async def get_family_participants_by_family_ids( - self, family_ids: list[int], check_project_ids: bool = True + self, family_ids: list[int] ) -> dict[int, list[PedRowInternal]]: """Get family participants for family IDs""" projects, fps = await self.fptable.query( @@ -429,15 +421,14 @@ async def get_family_participants_by_family_ids( if not fps: return {} - if check_project_ids: - await self.ptable.check_access_to_project_ids( - self.connection.author, projects, readonly=True - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) return group_by(fps, lambda r: r.family_id) async def get_family_participants_for_participants( - self, participant_ids: list[int], check_project_ids: bool = True + self, participant_ids: list[int] ) -> list[PedRowInternal]: """Get family participants for participant IDs""" projects, fps = await 
self.fptable.query( @@ -447,9 +438,8 @@ async def get_family_participants_for_participants( if not fps: return [] - if check_project_ids: - await self.ptable.check_access_to_project_ids( - self.connection.author, projects, readonly=True - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) return fps diff --git a/db/python/layers/participant.py b/db/python/layers/participant.py index fe5658ec7..b2754eafb 100644 --- a/db/python/layers/participant.py +++ b/db/python/layers/participant.py @@ -4,6 +4,7 @@ from enum import Enum from typing import Any +from db.python.connect import Connection from db.python.filters import GenericFilter from db.python.layers.base import BaseLayer from db.python.layers.sample import SampleLayer @@ -16,9 +17,10 @@ from db.python.tables.participant_phenotype import ParticipantPhenotypeTable from db.python.tables.sample import SampleTable from db.python.utils import NoOpAenter, NotFoundError, split_generic_terms -from models.models import PRIMARY_EXTERNAL_ORG, ProjectId +from models.models import PRIMARY_EXTERNAL_ORG from models.models.family import PedRowInternal from models.models.participant import ParticipantInternal, ParticipantUpsertInternal +from models.models.project import FullWriteAccessRoles, ProjectId, ReadAccessRoles HPO_REGEX_MATCHER = re.compile(r'HP\:\d+$') @@ -234,7 +236,7 @@ def process_hpo_term(term): class ParticipantLayer(BaseLayer): """Layer for more complex sample logic""" - def __init__(self, connection): + def __init__(self, connection: Connection): super().__init__(connection) self.pttable = ParticipantTable(connection=connection) @@ -243,7 +245,6 @@ async def query( filter_: ParticipantFilter, limit: int | None = None, skip: int | None = None, - check_project_ids: bool = True, ) -> list[ParticipantInternal]: """ Query participants from the database, heavy lifting done by the filter @@ -255,10 +256,9 @@ async def query( if not participants: return [] - if 
check_project_ids: - await self.ptable.check_access_to_project_ids( - self.author, projects, readonly=True - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) return participants @@ -283,7 +283,6 @@ async def query_count( async def get_participants_by_ids( self, pids: list[int], - check_project_ids: bool = True, allow_missing: bool = False, ) -> list[ParticipantInternal]: """ @@ -295,10 +294,9 @@ async def get_participants_by_ids( if not participants: return [] - if check_project_ids: - await self.ptable.check_access_to_project_ids( - self.author, projects, readonly=True - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) if not allow_missing and len(participants) != len(pids): # participants are missing @@ -336,7 +334,7 @@ async def fill_in_missing_participants(self): # { external_id: internal_id } samples_with_no_pid = ( await sample_table.get_samples_with_missing_participants_by_internal_id( - project=self.connection.project + project=self.connection.project_id ) ) external_sample_map_with_no_pid = { @@ -351,7 +349,7 @@ async def fill_in_missing_participants(self): unlinked_participants = await self.get_id_map_by_external_ids( list(external_sample_map_with_no_pid.keys()), - project=self.connection.project, + project=self.connection.project_id, allow_missing=True, ) @@ -442,10 +440,10 @@ async def generic_individual_metadata_importer( allow_missing_participants = ( extra_participants_method != ExtraParticipantImporterHandler.FAIL ) - assert self.connection.project + assert self.connection.project_id external_pid_map = await self.get_id_map_by_external_ids( list(external_participant_ids), - project=self.connection.project, + project=self.connection.project_id, allow_missing=allow_missing_participants, ) if extra_participants_method == ExtraParticipantImporterHandler.ADD: @@ -459,7 +457,7 @@ async def generic_individual_metadata_importer( reported_gender=None, 
karyotype=None, meta=None, - project=self.connection.project, + project=self.connection.project_id, ) elif extra_participants_method == ExtraParticipantImporterHandler.IGNORE: rows = [ @@ -492,7 +490,7 @@ async def generic_individual_metadata_importer( fmap_by_internal = await ftable.get_id_map_by_internal_ids(list(fids)) fmap_from_external = await ftable.get_id_map_by_external_ids( list(external_family_ids), - project=self.connection.project, + project=self.connection.project_id, allow_missing=True, ) fmap_by_external = { @@ -565,7 +563,7 @@ async def generic_individual_metadata_importer( return True async def get_participants_by_families( - self, family_ids: list[int], check_project_ids: bool = True + self, family_ids: list[int] ) -> dict[int, list[ParticipantInternal]]: """Get participants, keyed by family ID""" projects, family_map = await self.pttable.get_participants_by_families( @@ -574,10 +572,9 @@ async def get_participants_by_families( if not family_map: return {} - if check_project_ids: - await self.ptable.check_access_to_project_ids( - self.connection.author, projects, readonly=True - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) return family_map @@ -625,7 +622,6 @@ async def upsert_participant( self, participant: ParticipantUpsertInternal, project: ProjectId = None, - check_project_id: bool = True, open_transaction=True, ) -> ParticipantUpsertInternal: """Create a single participant""" @@ -637,16 +633,13 @@ async def upsert_participant( async with with_function(): if participant.id: - if check_project_id: - project_ids = ( - await self.pttable.get_project_ids_for_participant_ids( - [participant.id] - ) - ) + project_ids = await self.pttable.get_project_ids_for_participant_ids( + [participant.id] + ) - await self.ptable.check_access_to_project_ids( - self.connection.author, project_ids, readonly=False - ) + self.connection.check_access_to_projects_for_ids( + project_ids, 
allowed_roles=FullWriteAccessRoles + ) await self.pttable.update_participant( participant_id=participant.id, external_ids=participant.external_ids, @@ -674,7 +667,6 @@ async def upsert_participant( await slayer.upsert_samples( participant.samples, project=project, - check_project_id=False, open_transaction=False, ) @@ -700,16 +692,15 @@ async def upsert_participants( return participants async def update_many_participant_external_ids( - self, internal_to_external_id: dict[int, str], check_project_ids=True + self, internal_to_external_id: dict[int, str] ): """Update many participant external ids""" - if check_project_ids: - projects = await self.pttable.get_project_ids_for_participant_ids( - list(internal_to_external_id.keys()) - ) - await self.ptable.check_access_to_project_ids( - user=self.author, project_ids=projects, readonly=False - ) + projects = await self.pttable.get_project_ids_for_participant_ids( + list(internal_to_external_id.keys()) + ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=FullWriteAccessRoles + ) return await self.pttable.update_many_participant_external_ids( internal_to_external_id @@ -749,12 +740,12 @@ async def get_seqr_individual_template( internal_to_external_fid_map = {} if external_participant_ids or internal_participant_ids: - assert self.connection.project + assert self.connection.project_id pids = set(internal_participant_ids or []) if external_participant_ids: pid_map = await self.get_id_map_by_external_ids( external_participant_ids, - project=self.connection.project, + project=self.connection.project_id, allow_missing=False, ) pids |= set(pid_map.values()) @@ -822,7 +813,7 @@ async def get_seqr_individual_template( } async def get_family_participant_data( - self, family_id: int, participant_id: int, check_project_ids: bool = True + self, family_id: int, participant_id: int ) -> PedRowInternal: """Gets the family_participant row for a specific participant""" fptable = 
FamilyParticipantTable(self.connection) @@ -838,10 +829,10 @@ async def get_family_participant_data( f'Family participant row (family_id: {family_id}, ' f'participant_id: {participant_id}) not found' ) - if check_project_ids: - await self.ptable.check_access_to_project_ids( - self.author, projects, readonly=True - ) + + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) return rows[0] @@ -991,10 +982,9 @@ async def check_project_access_for_participants_families( ) ftable = FamilyTable(self.connection) fprojects = await ftable.get_projects_by_family_ids(family_ids=family_ids) - return await self.ptable.check_access_to_project_ids( - self.connection.author, - list(pprojects | fprojects), - readonly=True, + + return self.connection.check_access_to_projects_for_ids( + list(pprojects | fprojects), allowed_roles=ReadAccessRoles ) async def update_participant_family( diff --git a/db/python/layers/sample.py b/db/python/layers/sample.py index 3ed7179da..10bf3d77e 100644 --- a/db/python/layers/sample.py +++ b/db/python/layers/sample.py @@ -6,10 +6,9 @@ from db.python.layers.assay import AssayLayer from db.python.layers.base import BaseLayer, Connection from db.python.layers.sequencing_group import SequencingGroupLayer -from db.python.tables.project import ProjectPermissionsTable from db.python.tables.sample import SampleFilter, SampleTable from db.python.utils import NoOpAenter, NotFoundError -from models.models.project import ProjectId +from models.models.project import FullWriteAccessRoles, ProjectId, ReadAccessRoles from models.models.sample import SampleInternal, SampleUpsertInternal from models.utils.sample_id_format import sample_id_format_list @@ -20,37 +19,33 @@ class SampleLayer(BaseLayer): def __init__(self, connection: Connection): super().__init__(connection) self.st: SampleTable = SampleTable(connection) - self.pt = ProjectPermissionsTable(connection) self.connection = connection # GETS - async def get_by_id(self, 
sample_id: int, check_project_id=True) -> SampleInternal: + async def get_by_id(self, sample_id: int) -> SampleInternal: """Get sample by internal sample id""" project, sample = await self.st.get_sample_by_id(sample_id) - if check_project_id: - await self.pt.check_access_to_project_ids( - self.connection.author, [project], readonly=True - ) + + self.connection.check_access_to_projects_for_ids( + [project], allowed_roles=ReadAccessRoles + ) return sample - async def query( - self, filter_: SampleFilter, check_project_ids: bool = True - ) -> list[SampleInternal]: + async def query(self, filter_: SampleFilter) -> list[SampleInternal]: """Query samples""" projects, samples = await self.st.query(filter_) if not samples: return samples - if check_project_ids: - await self.pt.check_access_to_project_ids( - self.connection.author, projects, readonly=True - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) return samples async def get_samples_by_participants( - self, participant_ids: list[int], check_project_ids: bool = True + self, participant_ids: list[int] ) -> dict[int, list[SampleInternal]]: """Get map of samples by participants""" @@ -63,10 +58,9 @@ async def get_samples_by_participants( if not samples: return {} - if check_project_ids: - await self.ptable.check_access_to_project_ids( - self.author, projects, readonly=True - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) grouped_samples = group_by(samples, lambda s: s.participant_id) @@ -76,15 +70,12 @@ async def get_project_ids_for_sample_ids(self, sample_ids: list[int]) -> set[int """Return the projects associated with the sample ids""" return await self.st.get_project_ids_for_sample_ids(sample_ids) - async def get_sample_by_id( - self, sample_id: int, check_project_id=True - ) -> SampleInternal: + async def get_sample_by_id(self, sample_id: int) -> SampleInternal: """Get sample by ID""" project, sample = await 
self.st.get_sample_by_id(sample_id) - if check_project_id: - await self.pt.check_access_to_project_ids( - self.author, [project], readonly=True - ) + self.connection.check_access_to_projects_for_ids( + [project], allowed_roles=ReadAccessRoles + ) return sample @@ -104,7 +95,7 @@ async def get_sample_id_map_by_external_ids( ) -> dict[str, int]: """Get map of samples {(any) external_id: internal_id}""" external_ids_set = set(external_ids) - _project = project or self.connection.project + _project = project or self.connection.project_id assert _project sample_id_map = await self.st.get_sample_id_map_by_external_ids( external_ids=list(external_ids_set), project=_project @@ -121,7 +112,7 @@ async def get_sample_id_map_by_external_ids( ) async def get_internal_to_external_sample_id_map( - self, sample_ids: list[int], check_project_ids=True, allow_missing=False + self, sample_ids: list[int], allow_missing=False ) -> dict[int, str]: """Get map of internal sample id to external id""" @@ -146,10 +137,9 @@ async def get_internal_to_external_sample_id_map( if not sample_id_map: return {} - if check_project_ids: - await self.ptable.check_access_to_project_ids( - self.author, projects, readonly=True - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) return sample_id_map @@ -166,18 +156,17 @@ async def get_samples_by( participant_ids: list[int] | None = None, project_ids=None, active=True, - check_project_ids=True, ) -> list[SampleInternal]: """Get samples by some criteria""" if not sample_ids and not project_ids: raise ValueError('Must specify one of "project_ids" or "sample_ids"') - if sample_ids and check_project_ids: + if sample_ids: # project_ids were already checked when transformed to ints, # so no else required pjcts = await self.st.get_project_ids_for_sample_ids(sample_ids) - await self.ptable.check_access_to_project_ids( - self.author, pjcts, readonly=True + self.connection.check_access_to_projects_for_ids( + pjcts, 
allowed_roles=ReadAccessRoles ) _returned_project_ids, samples = await self.st.query( @@ -192,10 +181,9 @@ async def get_samples_by( if not samples: return [] - if not project_ids and check_project_ids: - await self.ptable.check_access_to_project_ids( - self.author, _returned_project_ids, readonly=True - ) + self.connection.check_access_to_projects_for_ids( + _returned_project_ids, allowed_roles=ReadAccessRoles + ) return samples @@ -211,7 +199,9 @@ async def get_samples_create_date( ) -> dict[int, datetime.date]: """Get a map of {internal_sample_id: date_created} for list of sample_ids""" pjcts = await self.st.get_project_ids_for_sample_ids(sample_ids) - await self.pt.check_access_to_project_ids(self.author, pjcts, readonly=True) + self.connection.check_access_to_projects_for_ids( + pjcts, allowed_roles=ReadAccessRoles + ) return await self.st.get_samples_create_date(sample_ids) # CREATE / UPDATES @@ -281,7 +271,6 @@ async def upsert_samples( samples: list[SampleUpsertInternal], open_transaction: bool = True, project: ProjectId = None, - check_project_id=True, ) -> list[SampleUpsertInternal]: """Batch upsert a list of samples with sequences""" seqglayer: SequencingGroupLayer = SequencingGroupLayer(self.connection) @@ -290,13 +279,12 @@ async def upsert_samples( self.connection.connection.transaction if open_transaction else NoOpAenter ) - if check_project_id: - sids = [s.id for s in samples if s.id] - if sids: - pjcts = await self.st.get_project_ids_for_sample_ids(sids) - await self.ptable.check_access_to_project_ids( - self.author, pjcts, readonly=False - ) + sids = [s.id for s in samples if s.id] + if sids: + pjcts = await self.st.get_project_ids_for_sample_ids(sids) + self.connection.check_access_to_projects_for_ids( + pjcts, allowed_roles=ReadAccessRoles + ) async with with_function(): # Create or update samples @@ -407,14 +395,12 @@ async def merge_samples( self, id_keep: int, id_merge: int, - check_project_id=True, ): """Merge two samples into one another""" 
- if check_project_id: - projects = await self.st.get_project_ids_for_sample_ids([id_keep, id_merge]) - await self.ptable.check_access_to_project_ids( - user=self.author, project_ids=projects, readonly=False - ) + projects = await self.st.get_project_ids_for_sample_ids([id_keep, id_merge]) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=FullWriteAccessRoles + ) return await self.st.merge_samples( id_keep=id_keep, @@ -422,7 +408,7 @@ async def merge_samples( ) async def update_many_participant_ids( - self, ids: list[int], participant_ids: list[int], check_sample_ids=True + self, ids: list[int], participant_ids: list[int] ) -> bool: """ Update participant IDs for many samples @@ -432,27 +418,24 @@ async def update_many_participant_ids( raise ValueError( f'Number of sampleIDs ({len(ids)}) and ParticipantIds ({len(participant_ids)}) did not match' ) - if check_sample_ids: - project_ids = await self.st.get_project_ids_for_sample_ids(ids) - await self.ptable.check_access_to_project_ids( - self.author, project_ids, readonly=False - ) + + projects = await self.st.get_project_ids_for_sample_ids(ids) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=FullWriteAccessRoles + ) await self.st.update_many_participant_ids( ids=ids, participant_ids=participant_ids ) return True - async def get_history_of_sample( - self, id_: int, check_sample_ids: bool = True - ) -> list[SampleInternal]: + async def get_history_of_sample(self, id_: int) -> list[SampleInternal]: """Get the full history of a sample""" rows = await self.st.get_history_of_sample(id_) - if check_sample_ids: - project_ids = set(r.project for r in rows) - await self.ptable.check_access_to_project_ids( - self.author, project_ids, readonly=True - ) + projects = set(r.project for r in rows) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=FullWriteAccessRoles + ) return rows diff --git a/db/python/layers/search.py 
b/db/python/layers/search.py index 0115b8bfa..dbba250ef 100644 --- a/db/python/layers/search.py +++ b/db/python/layers/search.py @@ -4,7 +4,6 @@ from db.python.layers.base import BaseLayer, Connection from db.python.tables.family import FamilyTable from db.python.tables.participant import ParticipantTable -from db.python.tables.project import ProjectPermissionsTable from db.python.tables.sample import SampleTable from db.python.tables.sequencing_group import SequencingGroupTable from db.python.utils import NotFoundError @@ -28,7 +27,6 @@ class SearchLayer(BaseLayer): def __init__(self, connection: Connection): super().__init__(connection) - self.pt = ProjectPermissionsTable(connection) self.connection = connection @staticmethod diff --git a/db/python/layers/seqr.py b/db/python/layers/seqr.py index 5e6ab8f92..ae2c0cd7b 100644 --- a/db/python/layers/seqr.py +++ b/db/python/layers/seqr.py @@ -134,14 +134,13 @@ async def sync_dataset( raise ValueError('Seqr synchronisation is not configured in metamist') token = self.generate_seqr_auth_token() - project = await self.ptable.get_and_check_access_to_project_for_id( - self.connection.author, - project_id=self.connection.project, - readonly=True, - ) + project = self.connection.project + assert project - seqr_guid = project.meta.get( - self.get_meta_key_from_sequencing_type(sequencing_type) + seqr_guid = ( + project.meta.get(self.get_meta_key_from_sequencing_type(sequencing_type)) + if project.meta + else None ) if not seqr_guid: @@ -430,8 +429,9 @@ async def update_es_index( sequencing_group_ids: set[int], ) -> list[str]: """Update seqr samples for latest elastic-search index""" + assert self.connection.project_id eid_to_sgid_rows = await self.player.get_external_participant_id_to_internal_sequencing_group_id_map( - self.connection.project, sequencing_type=sequencing_type + self.connection.project_id, sequencing_type=sequencing_type ) # format sample ID for transport @@ -464,7 +464,7 @@ async def update_es_index( alayer = 
AnalysisLayer(connection=self.connection) es_index_analyses = await alayer.query( AnalysisFilter( - project=GenericFilter(eq=self.connection.project), + project=GenericFilter(eq=self.connection.project_id), type=GenericFilter(eq='es-index'), status=GenericFilter(eq=AnalysisStatus.COMPLETED), meta={ @@ -542,8 +542,9 @@ async def sync_cram_map( alayer = AnalysisLayer(self.connection) + assert self.connection.project_id reads_map = await alayer.get_sample_cram_path_map_for_seqr( - project=self.connection.project, + project=self.connection.project_id, sequencing_types=[sequencing_type], participant_ids=participant_ids, ) @@ -619,9 +620,9 @@ async def _make_update_igv_call(update): async def _get_pedigree_from_sm(self, family_ids: set[int]) -> list[dict] | None: """Call get_pedigree and return formatted string with header""" - + assert self.connection.project_id ped_rows = await self.flayer.get_pedigree( - self.connection.project, + self.connection.project_id, family_ids=list(family_ids), replace_with_family_external_ids=True, replace_with_participant_external_ids=True, @@ -678,8 +679,9 @@ async def get_individual_meta_objs_for_seqr( self, participant_ids: list[int] ) -> list[dict] | None: """Get formatted list of dictionaries for syncing individual meta to seqr""" + assert self.connection.project_id individual_metadata_resp = await self.player.get_seqr_individual_template( - self.connection.project, internal_participant_ids=participant_ids + self.connection.project_id, internal_participant_ids=participant_ids ) json_rows: list[dict] = individual_metadata_resp['rows'] diff --git a/db/python/layers/sequencing_group.py b/db/python/layers/sequencing_group.py index 211dd0ce3..54ce5c864 100644 --- a/db/python/layers/sequencing_group.py +++ b/db/python/layers/sequencing_group.py @@ -11,7 +11,7 @@ SequencingGroupTable, ) from db.python.utils import NotFoundError -from models.models.project import ProjectId +from models.models.project import ProjectId, ReadAccessRoles from 
models.models.sequencing_group import ( SequencingGroupInternal, SequencingGroupInternalId, @@ -29,19 +29,17 @@ def __init__(self, connection: Connection): self.sampt: SampleTable = SampleTable(connection) async def get_sequencing_group_by_id( - self, sequencing_group_id: int, check_project_id: bool = True + self, sequencing_group_id: int ) -> SequencingGroupInternal: """ Get sequencing group by internal ID """ - groups = await self.get_sequencing_groups_by_ids( - [sequencing_group_id], check_project_ids=check_project_id - ) + groups = await self.get_sequencing_groups_by_ids([sequencing_group_id]) return groups[0] async def get_sequencing_groups_by_ids( - self, sequencing_group_ids: list[int], check_project_ids: bool = True + self, sequencing_group_ids: list[int] ) -> list[SequencingGroupInternal]: """ Get sequence groups by internal IDs @@ -56,10 +54,9 @@ async def get_sequencing_groups_by_ids( if not groups: return [] - if check_project_ids: - await self.ptable.check_access_to_project_ids( - self.author, projects, readonly=True - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) if len(groups) != len(sequencing_group_ids): missing_ids = set(sequencing_group_ids) - set(sg.id for sg in groups) @@ -71,7 +68,7 @@ async def get_sequencing_groups_by_ids( return groups async def get_sequencing_groups_by_analysis_ids( - self, analysis_ids: list[int], check_project_ids: bool = True + self, analysis_ids: list[int] ) -> dict[int, list[SequencingGroupInternal]]: """ Get sequencing groups by analysis IDs @@ -86,17 +83,15 @@ async def get_sequencing_groups_by_analysis_ids( if not groups: return groups - if check_project_ids: - await self.ptable.check_access_to_project_ids( - self.author, projects, readonly=True - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) return groups async def query( self, filter_: SequencingGroupFilter, - check_project_ids: bool = True, ) -> 
list[SequencingGroupInternal]: """ Query sequencing groups @@ -105,10 +100,9 @@ async def query( if not sequencing_groups: return [] - if check_project_ids and not (filter_.project and filter_.project.in_): - await self.ptable.check_access_to_project_ids( - self.author, projects, readonly=True - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) return sequencing_groups @@ -141,7 +135,7 @@ async def get_all_sequencing_group_ids_by_sample_ids_by_type( return await self.seqgt.get_all_sequencing_group_ids_by_sample_ids_by_type() async def get_participant_ids_sequencing_group_ids_for_sequencing_type( - self, sequencing_type: str, check_project_ids: bool = True + self, sequencing_type: str ) -> dict[int, list[int]]: """ Get list of partiicpant IDs for a specific sequence type, @@ -156,10 +150,9 @@ async def get_participant_ids_sequencing_group_ids_for_sequencing_type( if not pids: return {} - if check_project_ids: - await self.ptable.check_access_to_project_ids( - self.author, projects, readonly=True - ) + self.connection.check_access_to_projects_for_ids( + projects, allowed_roles=ReadAccessRoles + ) return pids diff --git a/db/python/layers/web.py b/db/python/layers/web.py index 6498f0b6a..e217b37a0 100644 --- a/db/python/layers/web.py +++ b/db/python/layers/web.py @@ -17,7 +17,6 @@ from db.python.tables.assay import AssayFilter, AssayTable from db.python.tables.base import DbBase from db.python.tables.participant import ParticipantFilter -from db.python.tables.project import ProjectPermissionsTable from db.python.tables.sequencing_group import SequencingGroupTable from models.models import ( AssayInternal, @@ -69,12 +68,12 @@ class WebDb(DbBase): async def get_total_number_of_samples(self): """Get total number of active samples within a project""" _query = 'SELECT COUNT(*) FROM sample WHERE project = :project AND active' - return await self.connection.fetch_val(_query, {'project': self.project}) + return await 
self.connection.fetch_val(_query, {'project': self.project_id}) async def get_total_number_of_participants(self): """Get total number of participants within a project""" _query = 'SELECT COUNT(*) FROM participant WHERE project = :project' - return await self.connection.fetch_val(_query, {'project': self.project}) + return await self.connection.fetch_val(_query, {'project': self.project_id}) async def get_total_number_of_sequencing_groups(self): """Get total number of sequencing groups within a project""" @@ -83,7 +82,7 @@ async def get_total_number_of_sequencing_groups(self): FROM sequencing_group sg INNER JOIN sample s ON s.id = sg.sample_id WHERE project = :project AND NOT sg.archived""" - return await self.connection.fetch_val(_query, {'project': self.project}) + return await self.connection.fetch_val(_query, {'project': self.project_id}) async def get_total_number_of_assays(self): """Get total number of sequences within a project""" @@ -92,7 +91,7 @@ async def get_total_number_of_assays(self): FROM assay sq INNER JOIN sample s ON s.id = sq.sample_id WHERE s.project = :project""" - return await self.connection.fetch_val(_query, {'project': self.project}) + return await self.connection.fetch_val(_query, {'project': self.project_id}) def get_seqr_links_from_project(self, project: WebProject) -> dict[str, str]: """ @@ -119,15 +118,11 @@ async def get_project_summary( :param token: for PAGING :param limit: Number of SAMPLEs to return, not including nested sequences """ - if not self.project: - raise ValueError('Project not provided') - ptable = ProjectPermissionsTable(self._connection) - project_db = await ptable.get_and_check_access_to_project_for_id( - self.author, self.project, readonly=True - ) + project_db = self.project + if not project_db: - raise ValueError(f'Project {self.project} not found') + raise ValueError('Project not provided') project = WebProject( id=project_db.id, @@ -156,10 +151,12 @@ async def get_project_summary( 
self.get_total_number_of_participants(), self.get_total_number_of_sequencing_groups(), self.get_total_number_of_assays(), - atable.get_number_of_crams_by_sequencing_type(project=self.project), - sgtable.get_type_numbers_for_project(project=self.project), - seqtable.get_assay_type_numbers_by_batch_for_project(project=self.project), - atable.get_seqr_stats_by_sequencing_type(project=self.project), + atable.get_number_of_crams_by_sequencing_type(project=self.project_id), + sgtable.get_type_numbers_for_project(project=self.project_id), + seqtable.get_assay_type_numbers_by_batch_for_project( + project=self.project_id + ), + atable.get_seqr_stats_by_sequencing_type(project=self.project_id), SeqrLayer(self._connection).get_synchronisable_types(project_db), ) diff --git a/db/python/tables/analysis.py b/db/python/tables/analysis.py index 16765c678..ba04e2d7a 100644 --- a/db/python/tables/analysis.py +++ b/db/python/tables/analysis.py @@ -72,7 +72,7 @@ async def create_analysis( ('meta', to_db_json(meta or {})), ('output', output), ('audit_log_id', await self.audit_log_id()), - ('project', project or self.project), + ('project', project or self.project_id), ('active', active if active is not None else True), ] @@ -333,7 +333,7 @@ async def get_all_sequencing_group_ids_without_analysis_type( rows = await self.connection.fetch_all( _query, - {'analysis_type': analysis_type, 'project': project or self.project}, + {'analysis_type': analysis_type, 'project': project or self.project_id}, ) return [row[0] for row in rows] diff --git a/db/python/tables/analysis_runner.py b/db/python/tables/analysis_runner.py index 2e93e0c74..0bc3bfa89 100644 --- a/db/python/tables/analysis_runner.py +++ b/db/python/tables/analysis_runner.py @@ -84,7 +84,7 @@ async def insert_analysis_runner_entry( 'meta': to_db_json(analysis_runner.meta), 'output_path': analysis_runner.output_path, 'audit_log_id': await self.audit_log_id(), - 'project': self.project, + 'project': self.project_id, } await 
self.connection.execute(_query, values) diff --git a/db/python/tables/assay.py b/db/python/tables/assay.py index 57e701112..5c3438e7b 100644 --- a/db/python/tables/assay.py +++ b/db/python/tables/assay.py @@ -139,12 +139,12 @@ async def get_assay_by_external_id( self, external_sequence_id: str, project: ProjectId | None = None ) -> AssayInternal: """Get assay by EXTERNAL ID""" - if not (project or self.project): + if not (project or self.project_id): raise ValueError('Getting assay by external ID requires a project') f = AssayFilter( external_id=GenericFilter(eq=external_sequence_id), - project=GenericFilter(eq=project or self.project), + project=GenericFilter(eq=project or self.project_id), ) _, assays = await self.query(f) @@ -328,7 +328,7 @@ async def insert_assay( ) if external_ids: - _project = project or self.project + _project = project or self.project_id if not _project: raise ValueError( 'When inserting an external identifier for a sequence, a ' @@ -343,7 +343,7 @@ async def insert_assay( audit_log_id = await self.audit_log_id() eid_values = [ { - 'project': project or self.project, + 'project': project or self.project_id, 'assay_id': id_of_new_assay, 'external_id': eid, 'name': name.lower(), @@ -422,7 +422,7 @@ async def update_assay( await self.connection.execute(_query, fields) if external_ids: - _project = project or self.project + _project = project or self.project_id if not _project: raise ValueError( 'When inserting or updating an external identifier for an ' diff --git a/db/python/tables/base.py b/db/python/tables/base.py index c069bf60a..2da7b5e12 100644 --- a/db/python/tables/base.py +++ b/db/python/tables/base.py @@ -25,6 +25,7 @@ def __init__(self, connection: Connection): self.connection: databases.Database = connection.connection self.author = connection.author self.project = connection.project + self.project_id = connection.project_id if self.author is None: raise InternalError(f'Must provide author to {self.__class__.__name__}') diff --git 
a/db/python/tables/cohort.py b/db/python/tables/cohort.py index bd821a504..5a4ced33b 100644 --- a/db/python/tables/cohort.py +++ b/db/python/tables/cohort.py @@ -4,7 +4,6 @@ from db.python.filters import GenericFilter, GenericFilterModel from db.python.tables.base import DbBase -from db.python.tables.project import ProjectId from db.python.utils import NotFoundError, to_db_json from models.models.cohort import ( CohortCriteriaInternal, @@ -12,6 +11,7 @@ CohortTemplateInternal, NewCohortInternal, ) +from models.models.project import ProjectId @dataclasses.dataclass(kw_only=True) diff --git a/db/python/tables/family.py b/db/python/tables/family.py index be3d94f08..b518ac6ee 100644 --- a/db/python/tables/family.py +++ b/db/python/tables/family.py @@ -204,7 +204,7 @@ async def create_family( 'description': description, 'coded_phenotype': coded_phenotype, 'audit_log_id': await self.audit_log_id(), - 'project': project or self.project, + 'project': project or self.project_id, } keys = list(updater.keys()) str_keys = ', '.join(keys) @@ -233,7 +233,7 @@ async def insert_or_update_multiple_families( 'description': descr, 'coded_phenotype': cph, 'audit_log_id': await self.audit_log_id(), - 'project': project or self.project, + 'project': project or self.project_id, } for eid, descr, cph in zip(external_ids, descriptions, coded_phenotypes) ] @@ -267,7 +267,7 @@ async def get_id_map_by_external_ids( _query = 'SELECT external_id, id FROM family WHERE external_id in :external_ids AND project = :project' results = await self.connection.fetch_all( - _query, {'external_ids': family_ids, 'project': project or self.project} + _query, {'external_ids': family_ids, 'project': project or self.project_id} ) id_map = {r['external_id']: r['id'] for r in results} diff --git a/db/python/tables/participant.py b/db/python/tables/participant.py index ee8bacd28..df0c5b35f 100644 --- a/db/python/tables/participant.py +++ b/db/python/tables/participant.py @@ -283,7 +283,7 @@ async def 
create_participant( """ Create a new sample, and add it to database """ - if not (project or self.project): + if not (project or self.project_id): raise ValueError('Must provide project to create participant') if not external_ids or external_ids.get(PRIMARY_EXTERNAL_ORG, None) is None: @@ -307,7 +307,7 @@ async def create_participant( 'karyotype': karyotype, 'meta': to_db_json(meta or {}), 'audit_log_id': audit_log_id, - 'project': project or self.project, + 'project': project or self.project_id, }, ) @@ -317,7 +317,7 @@ async def create_participant( """ _eid_values = [ { - 'project': project or self.project, + 'project': project or self.project_id, 'pid': new_id, 'name': name.lower(), 'external_id': eid, @@ -468,7 +468,7 @@ async def get_id_map_by_external_ids( project: ProjectId | None, ) -> dict[str, int]: """Get map of {external_id: internal_participant_id}""" - _project = project or self.project + _project = project or self.project_id if not _project: raise ValueError( 'Must provide project to get participant id map by external' diff --git a/db/python/tables/project.py b/db/python/tables/project.py index b08305044..74fa1c74f 100644 --- a/db/python/tables/project.py +++ b/db/python/tables/project.py @@ -1,20 +1,20 @@ # pylint: disable=global-statement -from typing import Any, Dict, Iterable, List +from typing import TYPE_CHECKING, Any, Tuple -from async_lru import alru_cache from databases import Database -from api.settings import is_all_access -from db.python.connect import Connection, SMConnections -from db.python.utils import ( - Forbidden, - InternalError, - NoProjectAccess, - NotFoundError, - get_logger, - to_db_json, +from db.python.utils import Forbidden, get_logger, to_db_json +from models.models.project import ( + Project, + ProjectMemberUpdate, + project_member_role_names, ) -from models.models.project import Project, ProjectId + +# Avoid circular import for type definition +if TYPE_CHECKING: + from db.python.connect import Connection +else: + 
Connection = object logger = get_logger() @@ -29,60 +29,21 @@ class ProjectPermissionsTable: table_name = 'project' - @staticmethod - def get_project_group_name(project_name: str, readonly: bool) -> str: - """ - Get group name for a project, for readonly / write - """ - if readonly: - return f'{project_name}-read' - return f'{project_name}-write' - def __init__( self, connection: Connection | None, - allow_full_access: bool | None = None, database_connection: Database | None = None, ): self._connection = connection - if not database_connection and not connection: - raise ValueError( - 'Must call project permissions table with either a direct ' - 'database_connection or a fully formed connection' - ) - self.connection: Database = database_connection or connection.connection - self.gtable = GroupTable(self.connection, allow_full_access=allow_full_access) - - @staticmethod - async def get_project_connection( - *, - author: str, - project_name: str, - readonly: bool, - ar_guid: str, - on_behalf_of: str | None = None, - meta: dict[str, str] | None = None, - ): - """Get a db connection from a project and user""" - # maybe it makes sense to perform permission checks here too - logger.debug(f'Authenticate connection to {project_name} with {author!r}') - - conn = await SMConnections.get_made_connection() - pt = ProjectPermissionsTable(connection=None, database_connection=conn) - - project = await pt.get_and_check_access_to_project_for_name( - user=author, project_name=project_name, readonly=readonly - ) - - return Connection( - connection=conn, - author=author, - project=project.id, - readonly=readonly, - on_behalf_of=on_behalf_of, - ar_guid=ar_guid, - meta=meta, - ) + if not database_connection: + if not connection: + raise ValueError( + 'Must call project permissions table with either a direct ' + 'database_connection or a fully formed connection' + ) + self.connection = connection.connection + else: + self.connection = database_connection async def audit_log_id(self): 
""" @@ -94,255 +55,125 @@ async def audit_log_id(self): ) return await self._connection.audit_log_id() - # region UNPROTECTED_GETS - - @alru_cache() - async def _get_project_rows_internal(self): - """ - Internally cached get_project_rows - """ - _query = """ - SELECT id, name, meta, dataset, read_group_id, write_group_id - FROM project - """ - rows = await self.connection.fetch_all(_query) - return list(map(Project.from_db, rows)) - - async def _get_project_id_map(self): - """ - Internally cached get_project_id_map - """ - return {p.id: p for p in await self._get_project_rows_internal()} - - async def _get_project_name_map(self) -> Dict[str, int]: - """Get {project_name: project_id} map""" - return {p.name: p.id for p in await self._get_project_rows_internal()} - - async def _get_project_by_id(self, project_id: ProjectId) -> Project: - """Get project by id""" - pmap = await self._get_project_id_map() - if project_id not in pmap: - raise NotFoundError(f'Could not find project {project_id}') - return pmap[project_id] - - async def _get_project_by_name(self, project_name: str) -> Project: - """Get project by name""" - pmap = await self._get_project_name_map() - if project_name not in pmap: - raise NotFoundError(f'Could not find project {project_name}') - return await self._get_project_by_id(pmap[project_name]) - - async def _get_projects_by_ids( - self, project_ids: Iterable[ProjectId] - ) -> List[Project]: - """Get projects by ids""" - pids = set(project_ids) - pmap = await self._get_project_id_map() - missing_pids = pids - set(pmap.keys()) - if missing_pids: - raise NotFoundError(f'Could not find projects {missing_pids}') - return [pmap[pid] for pid in pids] - - # endregion UNPROTECTED_GETS - # region AUTH - - async def get_all_projects(self, author: str): - """Get all projects""" - await self.check_project_creator_permissions(author) - return await self._get_project_rows_internal() - async def get_projects_accessible_by_user( - self, author: str, readonly=True - ) 
-> list[Project]: + self, user: str + ) -> tuple[dict[int, Project], dict[str, Project]]: """ Get projects that are accessible by the specified user """ - assert author - if self.gtable.allow_full_access: - return await self._get_project_rows_internal() + parameters: dict[str, str] = { + 'user': user, + 'project_creators_group_name': GROUP_NAME_PROJECT_CREATORS, + 'members_admin_group_name': GROUP_NAME_MEMBERS_ADMIN, + } - group_name = 'read_group_id' if readonly else 'write_group_id' - _query = f""" - SELECT p.id + _query = """ + -- Check what admin groups the user belongs to, if they belong + -- to project-creators then a project_admin role will be added to + -- all projects, if they belong to members-admin then a `project_member_admin` + -- role will be appended to all projects. + WITH admin_roles AS ( + SELECT + CASE (g.name) + WHEN :project_creators_group_name THEN 'project_admin' + WHEN :members_admin_group_name THEN 'project_member_admin' + END + as role + FROM `group` g + JOIN group_member gm + ON gm.group_id = g.id + WHERE gm.member = :user + AND g.name in (:project_creators_group_name, :members_admin_group_name) + ), + -- Combine together the project roles and the admin roles + project_roles AS ( + SELECT pm.project_id, pm.member, pm.role + FROM project_member pm + WHERE pm.member = :user + UNION ALL + SELECT p.id as project_id, :user as member, ar.role + FROM project p + JOIN admin_roles ar ON TRUE + ) + SELECT + p.id, + p.name, + p.meta, + p.dataset, + GROUP_CONCAT(pr.role) as roles FROM project p - INNER JOIN group_member gm ON gm.group_id = p.{group_name} - WHERE gm.member = :author + JOIN project_roles pr + ON p.id = pr.project_id + GROUP BY p.id """ - relevant_project_ids = await self.connection.fetch_all( - _query, {'author': author} - ) - projects = await self._get_projects_by_ids( - [p['id'] for p in relevant_project_ids] - ) - return projects + user_projects = await self.connection.fetch_all(_query, parameters) - async def 
get_and_check_access_to_project_for_id( - self, user: str, project_id: ProjectId, readonly: bool - ) -> Project: - """Get project by id""" - project = await self._get_project_by_id(project_id) - has_access = await self.gtable.check_if_member_in_group( - group_id=project.read_group_id if readonly else project.write_group_id, - member=user, - ) - if not has_access: - raise NoProjectAccess([project.name], readonly=readonly, author=user) - - return project - - async def get_and_check_access_to_project_for_name( - self, user: str, project_name: str, readonly: bool - ) -> Project: - """Get project by name + perform access checks""" - project = await self._get_project_by_name(project_name) - has_access = await self.gtable.check_if_member_in_group( - group_id=project.read_group_id if readonly else project.write_group_id, - member=user, - ) - if not has_access: - raise NoProjectAccess([project.name], readonly=readonly, author=user) + project_id_map: dict[int, Project] = {} + project_name_map: dict[str, Project] = {} - return project + for row in user_projects: + project = Project.from_db(dict(row)) + project_id_map[row['id']] = project + project_name_map[row['name']] = project - async def get_and_check_access_to_projects_for_names( - self, user: str, project_names: list[str], readonly: bool - ): - """Get projects by names + perform access checks""" - project_name_map = await self._get_project_name_map() - - # check missing_projects - missing_project_names = set(project_names) - set(project_name_map.keys()) - if missing_project_names: - raise NotFoundError( - f'Could not find projects {", ".join(missing_project_names)}' - ) + return project_id_map, project_name_map - projects = await self.get_and_check_access_to_projects_for_ids( - user=user, - project_ids=[project_name_map[name] for name in project_names], - readonly=readonly, + async def get_seqr_project_ids(self) -> list[int]: + """ + Get all projects with meta.is_seqr = true + """ + rows = await 
self.connection.fetch_all( + "SELECT id FROM project WHERE JSON_VALUE(meta, '$.is_seqr') = 1" ) + return [r['id'] for r in rows] - return projects - - async def get_and_check_access_to_projects_for_ids( - self, user: str, project_ids: list[ProjectId], readonly: bool - ) -> list[Project]: - """Get project by id""" - if not project_ids: - raise Forbidden( - "You don't have access to this resources, as the resource you " - "requested didn't belong to a project" - ) - - projects = await self._get_projects_by_ids(project_ids) - # do this all at once to save time - if readonly: - group_id_project_map = {p.read_group_id: p for p in projects} - else: - group_id_project_map = {p.write_group_id: p for p in projects} - - group_ids = set(gid for gid in group_id_project_map.keys() if gid) + async def check_if_member_in_group_by_name(self, group_name: str, member: str): + """Check if a user exists in the group""" - present_group_ids = await self.gtable.check_which_groups_member_has( - group_ids=group_ids, member=user - ) - missing_group_ids = group_ids - present_group_ids - - if missing_group_ids: - # so we can directly return the project names they're missing - missing_project_names = [ - group_id_project_map[gid].name or str(gid) for gid in missing_group_ids - ] - raise NoProjectAccess(missing_project_names, readonly=readonly, author=user) - - return projects - - async def check_access_to_project_id( - self, user: str, project_id: ProjectId, readonly: bool, raise_exception=True - ) -> bool: - """Check whether a user has access to project_id""" - project = await self._get_project_by_id(project_id) - has_access = await self.gtable.check_if_member_in_group( - group_id=project.read_group_id if readonly else project.write_group_id, - member=user, + _query = """ + SELECT COUNT(*) > 0 + FROM group_member gm + INNER JOIN `group` g ON g.id = gm.group_id + WHERE g.name = :group_name + AND gm.member = :member + """ + value = await self.connection.fetch_val( + _query, {'group_name': 
group_name, 'member': member} ) - if not has_access and raise_exception: - raise NoProjectAccess([project.name], readonly=readonly, author=user) - - return has_access - - async def check_access_to_project_ids( - self, - user: str, - project_ids: Iterable[ProjectId], - readonly: bool, - raise_exception=True, - ) -> bool: - """Check user has access to list of project_ids""" - if not project_ids: - raise Forbidden( - "You don't have access to this resources, as the resource you " - "requested didn't belong to a project" + if value not in (0, 1): + raise ValueError( + f'Unexpected value {value!r} when determining access to {group_name} ' + f'for {member}' ) - projects = await self._get_projects_by_ids(project_ids) - # do this all at once to save time - if readonly: - group_id_project_map = {p.read_group_id: p for p in projects} - else: - group_id_project_map = {p.write_group_id: p for p in projects} + return bool(value) - group_ids = set(gid for gid in group_id_project_map.keys() if gid) - present_group_ids = await self.gtable.check_which_groups_member_has( - group_ids=group_ids, member=user + async def check_project_creator_permissions(self, author: str): + """Check author has project_creator permissions""" + is_in_group = await self.check_if_member_in_group_by_name( + group_name=GROUP_NAME_PROJECT_CREATORS, member=author ) - missing_group_ids = group_ids - present_group_ids - - if missing_group_ids: - # so we can directly return the project names they're missing - missing_project_names = [ - group_id_project_map[gid].name or str(gid) for gid in missing_group_ids - ] - if raise_exception: - raise NoProjectAccess( - missing_project_names, readonly=readonly, author=user - ) - return False + if not is_in_group: + raise Forbidden(f'{author} does not have access to create a project') return True - async def check_project_creator_permissions(self, author): - """Check author has project_creator permissions""" - # check permissions in here - is_in_group = await 
self.gtable.check_if_member_in_group_name( - group_name=GROUP_NAME_PROJECT_CREATORS, member=author + async def check_member_admin_permissions(self, author: str): + """Check author has member_admin permissions""" + is_in_group = await self.check_if_member_in_group_by_name( + GROUP_NAME_MEMBERS_ADMIN, author ) - if not is_in_group: - raise Forbidden(f'{author} does not have access to creating project') + raise Forbidden( + f'User {author} does not have permission to edit project members' + ) return True # endregion AUTH - async def get_project_ids_from_names_and_user( - self, user: str, project_names: List[str], readonly: bool - ) -> List[ProjectId]: - """Get project ids from project names and the user""" - if not user: - raise InternalError('An internal error occurred during authorization') - - project_name_map = await self._get_project_name_map() - ordered_project_ids = [project_name_map[name] for name in project_names] - await self.check_access_to_project_ids( - user, ordered_project_ids, readonly=readonly, raise_exception=True - ) - - return ordered_project_ids - # region CREATE / UPDATE async def create_project( @@ -350,39 +181,25 @@ async def create_project( project_name: str, dataset_name: str, author: str, - check_permissions=True, ): """Create project row""" - if check_permissions: - await self.check_project_creator_permissions(author) + await self.check_project_creator_permissions(author) async with self.connection.transaction(): - audit_log_id = await self.audit_log_id() - read_group_id = await self.gtable.create_group( - self.get_project_group_name(project_name, readonly=True), - audit_log_id=audit_log_id, - ) - write_group_id = await self.gtable.create_group( - self.get_project_group_name(project_name, readonly=False), - audit_log_id=audit_log_id, - ) - _query = """\ - INSERT INTO project (name, dataset, audit_log_id, read_group_id, write_group_id) - VALUES (:name, :dataset, :audit_log_id, :read_group_id, :write_group_id) + INSERT INTO project (name, 
dataset, audit_log_id) + VALUES (:name, :dataset, :audit_log_id) RETURNING ID""" values = { 'name': project_name, 'dataset': dataset_name, 'audit_log_id': await self.audit_log_id(), - 'read_group_id': read_group_id, - 'write_group_id': write_group_id, } project_id = await self.connection.fetch_val(_query, values) - # pylint: disable=no-member - self._get_project_rows_internal.cache_invalidate() + if self._connection: + await self._connection.refresh_projects() return project_id @@ -392,7 +209,7 @@ async def update_project(self, project_name: str, update: dict, author: str): meta = update.get('meta') - fields: Dict[str, Any] = { + fields: dict[str, Any] = { 'audit_log_id': await self.audit_log_id(), 'name': project_name, } @@ -409,20 +226,13 @@ async def update_project(self, project_name: str, update: dict, author: str): await self.connection.execute(_query, fields) - # pylint: disable=no-member - self._get_project_rows_internal.cache_invalidate() - - async def delete_project_data( - self, project_id: int, delete_project: bool, author: str - ) -> bool: + async def delete_project_data(self, project_id: int, delete_project: bool) -> bool: """ Delete data in metamist project, requires project_creator_permissions - Can optionally delete the project also. 
""" if delete_project: # stop allowing delete project with analysis-runner entries raise ValueError('2024-03-08: delete_project is no longer allowed') - await self.check_project_creator_permissions(author) async with self.connection.transaction(): _query = """ @@ -465,209 +275,71 @@ async def delete_project_data( DELETE FROM participant WHERE project = :project; DELETE FROM analysis WHERE project = :project; """ - values: dict = {'project': project_id} - if delete_project: - group_ids = await self.connection.fetch_one( - """ - SELECT read_group_id, write_group_id - FROM project WHERE id = :project' - """ - ) - _query += 'DELETE FROM project WHERE id = :project;\n' - _query += 'DELETE FROM `group` WHERE id IN :group_ids\n' - values['group_ids'] = [ - group_ids['read_group_id'], - group_ids['write_group_id'], - ] await self.connection.execute(_query, {'project': project_id}) - if delete_project: - # pylint: disable=no-member - self._get_project_rows_internal.cache_invalidate() - return True - async def set_group_members(self, group_name: str, members: list[str], author: str): + async def set_project_members( + self, project: Project, members: list[ProjectMemberUpdate] + ): """ Set group members for a group (by name) """ - has_permission = await self.gtable.check_if_member_in_group_name( - GROUP_NAME_MEMBERS_ADMIN, author - ) - if not has_permission: - raise Forbidden( - f'User {author} does not have permission to add members to group {group_name}' - ) - group_id = await self.gtable.get_group_name_from_id(group_name) - await self.gtable.set_group_members( - group_id, members, audit_log_id=await self.audit_log_id() - ) - - # endregion CREATE / UPDATE - - async def get_seqr_projects(self) -> list[dict[str, Any]]: - """ - Get all projects with meta.is_seqr = true - """ - - all_projects = await self._get_project_rows_internal() - seqr_projects = [p for p in all_projects if p.meta.get('is_seqr')] - return seqr_projects - - # gruo - - -class GroupTable: - """ - Capture 
Analysis table operations and queries - """ - - table_name = 'group' + async with self.connection.transaction(): - def __init__(self, connection: Database, allow_full_access: bool = None): - if not isinstance(connection, Database): - raise ValueError( - f'Invalid type connection, expected Database, got {type(connection)}, ' - 'did you forget to call connection.connection?' + # Get existing rows so that we can keep the existing audit log ids + existing_rows = await self.connection.fetch_all( + """ + SELECT project_id, member, role, audit_log_id + FROM project_member + WHERE project_id = :project_id + """, + {'project_id': project.id}, ) - self.connection: Database = connection - self.allow_full_access = ( - allow_full_access if allow_full_access is not None else is_all_access() - ) - async def get_group_members(self, group_id: int) -> set[str]: - """Get project IDs for sampleIds (mostly for checking auth)""" - _query = """ - SELECT member - FROM `group` - WHERE id = :group_id - """ - rows = await self.connection.fetch_all(_query, {'group_id': group_id}) - members = set(r['member'] for r in rows) - return members - - async def get_group_name_from_id(self, name: str) -> int: - """Get group name to group id""" - _query = """ - SELECT id, name - FROM `group` - WHERE name = :name - """ - row = await self.connection.fetch_one(_query, {'name': name}) - if not row: - raise NotFoundError(f'Could not find group {name}') - return row['id'] - - async def get_group_name_to_ids(self, names: list[str]) -> dict[str, int]: - """Get group name to group id""" - _query = """ - SELECT id, name - FROM `group` - WHERE name IN :names - """ - rows = await self.connection.fetch_all(_query, {'names': names}) - return {r['name']: r['id'] for r in rows} - - async def check_if_member_in_group(self, group_id: int, member: str) -> bool: - """Check if a member is in a group""" - if self.allow_full_access: - return True + audit_log_id_map: dict[Tuple[str, str], int | None] = { + (r['member'], 
r['role']): r['audit_log_id'] for r in existing_rows + } - _query = """ - SELECT COUNT(*) > 0 - FROM group_member gm - WHERE gm.group_id = :group_id - AND gm.member = :member - """ - value = await self.connection.fetch_val( - _query, {'group_id': group_id, 'member': member} - ) - if value not in (0, 1): - raise ValueError( - f'Unexpected value {value!r} when determining access to group with ID ' - f'{group_id} for {member}' + # delete existing rows for project + await self.connection.execute( + """ + DELETE FROM project_member + WHERE project_id = :project_id + """, + {'project_id': project.id}, ) - return bool(value) - - async def check_if_member_in_group_name(self, group_name: str, member: str) -> bool: - """Check if a member is in a group""" - if self.allow_full_access: - return True - _query = """ - SELECT COUNT(*) > 0 - FROM group_member gm - INNER JOIN `group` g ON g.id = gm.group_id - WHERE g.name = :group_name - AND gm.member = :member - """ - value = await self.connection.fetch_val( - _query, {'group_name': group_name, 'member': member} - ) - if value not in (0, 1): - raise ValueError( - f'Unexpected value {value!r} when determining access to {group_name} ' - f'for {member}' + new_audit_log_id = await self.audit_log_id() + + db_members: list[dict[str, str]] = [] + + for m in members: + db_members.extend([{'member': m.member, 'role': r} for r in m.roles]) + + await self.connection.execute_many( + """ + INSERT INTO project_member + (project_id, member, role, audit_log_id) + VALUES (:project_id, :member, :role, :audit_log_id); + """, + [ + { + 'project_id': project.id, + 'member': m['member'], + 'role': m['role'], + 'audit_log_id': audit_log_id_map.get( + (m['member'], m['role']), new_audit_log_id + ), + } + for m in db_members + if m['role'] in project_member_role_names + ], ) - return bool(value) - - async def check_which_groups_member_has( - self, group_ids: set[int], member: str - ) -> set[int]: - """ - Check which groups a member has - """ - if 
self.allow_full_access: - return group_ids - - _query = """ - SELECT gm.group_id as gid - FROM group_member gm - WHERE gm.member = :member AND gm.group_id IN :group_ids - """ - results = await self.connection.fetch_all( - _query, {'group_ids': group_ids, 'member': member} - ) - return set(r['gid'] for r in results) - - async def create_group(self, name: str, audit_log_id: int) -> int: - """Create a new group""" - _query = """ - INSERT INTO `group` (name, audit_log_id) - VALUES (:name, :audit_log_id) - RETURNING id - """ - return await self.connection.fetch_val( - _query, {'name': name, 'audit_log_id': audit_log_id} - ) + if self._connection: + await self._connection.refresh_projects() - async def set_group_members( - self, group_id: int, members: list[str], audit_log_id: int - ): - """ - Set group members for a group (by id) - """ - await self.connection.execute( - """ - DELETE FROM group_member - WHERE group_id = :group_id - """, - {'group_id': group_id}, - ) - await self.connection.execute_many( - """ - INSERT INTO group_member (group_id, member, audit_log_id) - VALUES (:group_id, :member, :audit_log_id) - """, - [ - { - 'group_id': group_id, - 'member': member, - 'audit_log_id': audit_log_id, - } - for member in members - ], - ) + # endregion CREATE / UPDATE diff --git a/db/python/tables/sample.py b/db/python/tables/sample.py index a53ce30ce..54135a957 100644 --- a/db/python/tables/sample.py +++ b/db/python/tables/sample.py @@ -247,7 +247,7 @@ async def insert_sample( ('audit_log_id', audit_log_id), ('sample_parent_id', sample_parent_id), ('sample_root_id', sample_root_id), - ('project', project or self.project), + ('project', project or self.project_id), ] keys = [k for k, _ in kv_pairs] @@ -270,7 +270,7 @@ async def insert_sample( """ _eid_values = [ { - 'project': project or self.project, + 'project': project or self.project_id, 'id': id_of_new_sample, 'name': name.lower(), 'external_id': eid, @@ -548,7 +548,8 @@ async def get_sample_id_map_by_external_ids( 
WHERE external_id IN :external_ids AND project = :project """ rows = await self.connection.fetch_all( - _query, {'external_ids': external_ids, 'project': project or self.project} + _query, + {'external_ids': external_ids, 'project': project or self.project_id}, ) sample_id_map = {el[1]: el[0] for el in rows} @@ -584,7 +585,7 @@ async def get_all_sample_id_map_by_internal_ids( rows = await self.connection.fetch_all( _query, { - 'project': project or self.project, + 'project': project or self.project_id, 'PRIMARY_EXTERNAL_ORG': PRIMARY_EXTERNAL_ORG, }, ) @@ -641,7 +642,7 @@ async def get_samples_with_missing_participants_by_internal_id( _, samples = await self.query( SampleFilter( participant_id=GenericFilter(isnull=True), - project=GenericFilter(eq=project or self.project), + project=GenericFilter(eq=project or self.project_id), ) ) return samples diff --git a/db/python/tables/sequencing_group.py b/db/python/tables/sequencing_group.py index 829be7057..4d77a40d4 100644 --- a/db/python/tables/sequencing_group.py +++ b/db/python/tables/sequencing_group.py @@ -278,7 +278,7 @@ async def get_all_sequencing_group_ids_by_sample_ids_by_type( INNER JOIN sequencing_group sg ON s.id = sg.sample_id WHERE project = :project """ - rows = await self.connection.fetch_all(_query, {'project': self.project}) + rows = await self.connection.fetch_all(_query, {'project': self.project_id}) sequencing_group_ids_by_sample_ids_by_type: dict[int, dict[str, list[int]]] = ( defaultdict(lambda: defaultdict(list)) ) @@ -306,7 +306,7 @@ async def get_participant_ids_and_sequencing_group_ids_for_sequencing_type( rows = list( await self.connection.fetch_all( - _query, {'seqtype': sequencing_type, 'project': self.project} + _query, {'seqtype': sequencing_type, 'project': self.project_id} ) ) diff --git a/db/python/utils.py b/db/python/utils.py index a928c2070..70c3b49c9 100644 --- a/db/python/utils.py +++ b/db/python/utils.py @@ -2,7 +2,7 @@ import logging import os import re -from typing import 
Sequence, TypeVar +from typing import Any, TypeVar T = TypeVar('T') X = TypeVar('X') @@ -60,20 +60,18 @@ class NoProjectAccess(Forbidden): def __init__( self, - project_names: Sequence[str] | None, + project_names: list[str], author: str, - *args, - readonly: bool | None = None, + allowed_roles: list[str], + *args: tuple[Any, ...], ): project_names_str = ( ', '.join(repr(p) for p in project_names) if project_names else '' ) - access_type = '' - if readonly is False: - access_type = 'write ' + required_roles_str = ' or '.join(allowed_roles) super().__init__( - f'{author} does not have {access_type}access to resources from the ' + f'{author} does not have {required_roles_str} access to resources from the ' f'following project(s), or they may not exist: {project_names_str}', *args, ) @@ -108,9 +106,13 @@ def get_logger(): return _logger +def from_db_json(text): + """Convert DB's JSON text to Python object""" + return json.loads(text) + + def to_db_json(val): """Convert val to json for DB""" - # return psycopg2.extras.Json(val) return json.dumps(val) diff --git a/deploy/python/version.txt b/deploy/python/version.txt index 21c8c7b46..0ee843cc6 100644 --- a/deploy/python/version.txt +++ b/deploy/python/version.txt @@ -1 +1 @@ -7.1.1 +7.2.0 diff --git a/docs/installation.md b/docs/installation.md index cb2290afb..20608a80b 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -288,7 +288,7 @@ export SM_DEV_DB_PORT=3306 # or 3307 You'll want to set the following environment variables (permanently) in your local development environment. -The `SM_ENVIRONMENT`, `SM_LOCALONLY_DEFAULTUSER` and `SM_ALLOWALLACCESS` environment variables allow access to a local metamist server without providing a bearer token. +The `SM_ENVIRONMENT` and `SM_LOCALONLY_DEFAULTUSER` environment variables allow access to a local metamist server without providing a bearer token. This will allow you to test the front-end components that access data. 
This happens automatically on the production instance through the Google identity-aware-proxy. @@ -296,12 +296,19 @@ This will allow you to test the front-end components that access data. This happ # ensures the SWAGGER page points to your local: (localhost:8000/docs) # and ensures if you use the PythonAPI, it also points to your local export SM_ENVIRONMENT=LOCAL -# skips permission checks in your local environment -export SM_ALLOWALLACCESS=true # uses your username as the "author" in requests export SM_LOCALONLY_DEFAULTUSER=$(whoami) ``` +To allow the system to be bootstrapped and create the initial project, you'll need to add yourself to the two admin groups that allow creating projects and updating project members: + +```sql +INSERT INTO group_member(group_id, member) +SELECT id, '' +FROM `group` WHERE name IN('project-creators', 'members-admin') + +``` + With those variables set, it is a good time to populate some test data if this is your first time running this server: ```bash @@ -335,7 +342,6 @@ The following `launch.json` is a good base to debug the web server in VS Code: "module": "api.server", "justMyCode": false, "env": { - "SM_ALLOWALLACCESS": "true", "SM_LOCALONLY_DEFAULTUSER": "-local", "SM_ENVIRONMENT": "local", "SM_DEV_DB_USER": "sm_api", diff --git a/models/models/project.py b/models/models/project.py index e1243f274..df9d09c4b 100644 --- a/models/models/project.py +++ b/models/models/project.py @@ -1,23 +1,67 @@ -from typing import Optional +from enum import Enum +from typing import Any, Optional + +from pydantic import field_serializer from models.base import SMBase, parse_sql_dict ProjectId = int +ProjectMemberRole = Enum( + 'ProjectMemberRole', + [ + 'reader', + 'contributor', + 'writer', + 'project_admin', + 'project_member_admin', + ], +) + + +# These roles have read access to a project +ReadAccessRoles = { + ProjectMemberRole.reader, + ProjectMemberRole.contributor, + ProjectMemberRole.writer, +} + +# Only write has full write access 
+FullWriteAccessRoles = {ProjectMemberRole.writer} +project_member_role_names = [r.name for r in ProjectMemberRole] + class Project(SMBase): """Row for project in 'project' table""" - id: Optional[ProjectId] = None - name: Optional[str] = None - dataset: Optional[str] = None - meta: Optional[dict] = None - read_group_id: Optional[int] = None - write_group_id: Optional[int] = None + id: ProjectId + name: str + dataset: str + meta: Optional[dict[str, Any]] = None + roles: set[ProjectMemberRole] + """The roles that the current user has within the project""" + + @field_serializer('roles') + def serialize_roles(self, roles: set[ProjectMemberRole]): + """convert roles into a form that can be returned from the API""" + return [r.name for r in roles] @staticmethod def from_db(kwargs): """From DB row, with db keys""" kwargs = dict(kwargs) kwargs['meta'] = parse_sql_dict(kwargs.get('meta')) or {} + + # Sanitise role names and convert to enum members + role_list: list[str] = kwargs['roles'].split(',') if kwargs.get('roles') else [] + kwargs['roles'] = { + ProjectMemberRole[r] for r in role_list if r in project_member_role_names + } return Project(**kwargs) + + +class ProjectMemberUpdate(SMBase): + """Item included in list of project member updates""" + + member: str + roles: list[str] diff --git a/scripts/sync_seqr.py b/scripts/sync_seqr.py index 45b415da1..fa685666f 100644 --- a/scripts/sync_seqr.py +++ b/scripts/sync_seqr.py @@ -353,9 +353,11 @@ def _parse_consanguity(consanguity): def process_row(row): return { - seqr_key: key_processor[sm_key](row[sm_key]) - if sm_key in key_processor - else row[sm_key] + seqr_key: ( + key_processor[sm_key](row[sm_key]) + if sm_key in key_processor + else row[sm_key] + ) for seqr_key, sm_key in seqr_map.items() if sm_key in row } @@ -558,7 +560,7 @@ async def _make_update_igv_call(update): if not updates: continue print( - f'{dataset} :: Updating CRAMs {idx * chunk_size + 1} -> {(min((idx + 1 ) * chunk_size, len(all_updates)))} 
(/{len(all_updates)})' + f'{dataset} :: Updating CRAMs {idx * chunk_size + 1} -> {(min((idx + 1) * chunk_size, len(all_updates)))} (/{len(all_updates)})' ) responses = await asyncio.gather( diff --git a/setup.py b/setup.py index c2078672e..64203f264 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ setup( name=PKG, # This tag is automatically updated by bump2version - version='7.1.1', + version='7.2.0', description='Python API for interacting with the Sample API system', long_description=readme, long_description_content_type='text/markdown', diff --git a/test/data/generate_data.py b/test/data/generate_data.py index 02b44b53b..079d8a1cd 100755 --- a/test/data/generate_data.py +++ b/test/data/generate_data.py @@ -3,7 +3,9 @@ import argparse import asyncio import datetime +import os import random +import sys from pathlib import Path from pprint import pprint from uuid import uuid4 @@ -76,6 +78,19 @@ async def main(ped_path=default_ped_location, project='greek-myth'): await papi.create_project_async( name=project, dataset=project, create_test_project=False ) + default_user = os.getenv('SM_LOCALONLY_DEFAULTUSER') + if not default_user: + print( + 'SM_LOCALONLY_DEFAULTUSER env var is not set, please set it before generating data' + ) + sys.exit(1) + + await papi.update_project_members_async( + project=project, + project_member_update=[ + {'member': default_user, 'roles': ['reader', 'writer']} + ], + ) with open(ped_path, encoding='utf-8') as f: # skip the first line @@ -108,7 +123,7 @@ def generate_random_number_within_distribution(): )[0] samples = [] - sample_id_index = 1003 + sample_id_index = 10000 for participant_eid in participant_eids: pid = id_map[participant_eid] diff --git a/test/data/generate_seqr_project_data.py b/test/data/generate_seqr_project_data.py index 2daa6c31f..026c53e31 100644 --- a/test/data/generate_seqr_project_data.py +++ b/test/data/generate_seqr_project_data.py @@ -4,6 +4,7 @@ import csv import datetime import logging +import os import 
random import sys import tempfile @@ -13,9 +14,9 @@ from metamist.apis import AnalysisApi, FamilyApi, ParticipantApi, ProjectApi, SampleApi from metamist.graphql import gql, query_async from metamist.model.analysis import Analysis -from metamist.model.analysis_status import AnalysisStatus from metamist.models import AssayUpsert, SampleUpsert, SequencingGroupUpsert from metamist.parser.generic_parser import chunk +from models.enums import AnalysisStatus PRIMARY_EXTERNAL_ORG = '' @@ -93,6 +94,7 @@ class ped_row: """The pedigree row class""" + def __init__(self, values): ( self.family_id, @@ -158,24 +160,28 @@ def generate_pedigree_rows(num_families=1): individual_id = generate_random_id(used_ids) if parent_sex == 1: rows.append( - ped_row([ - family_id, - individual_id, - parent_id, - '', - random.choice([0, 1]), - 2,] + ped_row( + [ + family_id, + individual_id, + parent_id, + '', + random.choice([0, 1]), + 2, + ] ) ) else: rows.append( - ped_row([ - family_id, - individual_id, - '', - parent_id, - random.choice([0, 1]), - 2,] + ped_row( + [ + family_id, + individual_id, + '', + parent_id, + random.choice([0, 1]), + 2, + ] ) ) @@ -203,13 +209,24 @@ def generate_pedigree_rows(num_families=1): 0 ] # Randomly assign affected status rows.append( - ped_row([family_id, individual_id, paternal_id, maternal_id, sex, affected]) + ped_row( + [ + family_id, + individual_id, + paternal_id, + maternal_id, + sex, + affected, + ] + ) ) return rows -def generate_sequencing_type(count_distribution: dict[int, float], sequencing_types: list[str]): +def generate_sequencing_type( + count_distribution: dict[int, float], sequencing_types: list[str] +): """Return a random length of random sequencing types""" k = random.choices( list(count_distribution.keys()), @@ -274,7 +291,9 @@ async def generate_sample_entries( samples = [] for participant_eid, participant_id in participant_id_map.items(): - nsamples = generate_random_number_within_distribution(default_count_probabilities) + nsamples = 
generate_random_number_within_distribution( + default_count_probabilities + ) for i in range(nsamples): sample = SampleUpsert( external_ids={PRIMARY_EXTERNAL_ORG: f'{participant_eid}_{i+1}'}, @@ -291,7 +310,9 @@ async def generate_sample_entries( ) samples.append(sample) - for stype in generate_sequencing_type(default_count_probabilities, sequencing_types): + for stype in generate_sequencing_type( + default_count_probabilities, sequencing_types + ): facility = random.choice( [ 'Amazing sequence centre', @@ -311,13 +332,17 @@ async def generate_sample_entries( assays=[], ) sample.sequencing_groups.append(sg) - for _ in range(generate_random_number_within_distribution(default_count_probabilities)): + for _ in range( + generate_random_number_within_distribution( + default_count_probabilities + ) + ): sg.assays.append( AssayUpsert( type='sequencing', meta={ 'facility': facility, - 'reads' : [], + 'reads': [], 'coverage': f'{random.choice([30, 90, 300, 9000, "?"])}x', 'sequencing_type': stype, 'sequencing_technology': stechnology, @@ -341,7 +366,7 @@ async def generate_cram_analyses(project: str, analyses_to_insert: list[Analysis # Randomly allocate some of the sequencing groups to be aligned aligned_sgs = random.sample( sequencing_groups, - k=random.randint(int(len(sequencing_groups)/2), len(sequencing_groups)) + k=random.randint(int(len(sequencing_groups) / 2), len(sequencing_groups)), ) # Insert completed CRAM analyses for the aligned sequencing groups @@ -353,11 +378,14 @@ async def generate_cram_analyses(project: str, analyses_to_insert: list[Analysis status=AnalysisStatus('completed'), output=f'FAKE://{project}/crams/{sg["id"]}.cram', timestamp_completed=( - datetime.datetime.now() - datetime.timedelta(days=random.randint(1, 15)) + datetime.datetime.now() + - datetime.timedelta(days=random.randint(1, 15)) ).isoformat(), meta={ # random size between 5, 25 GB - 'size': random.randint(5 * 1024, 25 * 1024) * 1024 * 1024, + 'size': random.randint(5 * 1024, 25 * 1024) + 
* 1024 + * 1024, }, ) for sg in aligned_sgs @@ -367,7 +395,9 @@ async def generate_cram_analyses(project: str, analyses_to_insert: list[Analysis return aligned_sgs -async def generate_joint_called_analyses(project: str, aligned_sgs: list[dict], analyses_to_insert: list[Analysis]): +async def generate_joint_called_analyses( + project: str, aligned_sgs: list[dict], analyses_to_insert: list[Analysis] +): """ Selects a subset of the aligned sequencing groups for the input project and generates joint-called AnnotateDataset and ES-index analysis entries for them. @@ -375,7 +405,9 @@ async def generate_joint_called_analyses(project: str, aligned_sgs: list[dict], seq_type_to_sg_list = { 'genome': [sg['id'] for sg in aligned_sgs if sg['type'] == 'genome'], 'exome': [sg['id'] for sg in aligned_sgs if sg['type'] == 'exome'], - 'transcriptome': [sg['id'] for sg in aligned_sgs if sg['type'] == 'transcriptome'] + 'transcriptome': [ + sg['id'] for sg in aligned_sgs if sg['type'] == 'transcriptome' + ], } for seq_type, sg_list in seq_type_to_sg_list.items(): if not sg_list: @@ -397,7 +429,7 @@ async def generate_joint_called_analyses(project: str, aligned_sgs: list[dict], status=AnalysisStatus('completed'), output=f'FAKE::{project}-{seq_type}-es-{datetime.date.today()}', meta={'stage': 'MtToEs', 'sequencing_type': seq_type}, - ) + ), ] ) @@ -422,9 +454,25 @@ async def main(): await papi.create_project_async( name=project, dataset=project, create_test_project=False ) + + default_user = os.getenv('SM_LOCALONLY_DEFAULTUSER') + if not default_user: + print( + 'SM_LOCALONLY_DEFAULTUSER env var is not set, please set it before generating data' + ) + sys.exit(1) + + await papi.update_project_members_async( + project=project, + project_member_update=[ + {'member': default_user, 'roles': ['reader', 'writer']} + ], + ) + logging.info(f'Created project "{project}"') await papi.update_project_async( - project=project, body={'meta': {'is_seqr': 'true'}}, + project=project, + body={'meta': 
{'is_seqr': 'true'}}, ) logging.info(f'Set {project} as seqr project') @@ -438,7 +486,9 @@ async def main(): for analyses in chunk(analyses_to_insert, 50): logging.info(f'Inserting {len(analyses)} analysis entries') - await asyncio.gather(*[aapi.create_analysis_async(project, a) for a in analyses]) + await asyncio.gather( + *[aapi.create_analysis_async(project, a) for a in analyses] + ) if __name__ == '__main__': diff --git a/test/test_assay.py b/test/test_assay.py index ec735b128..98a938ea1 100644 --- a/test/test_assay.py +++ b/test/test_assay.py @@ -59,7 +59,7 @@ async def test_not_found_assay(self): @run_as_sync async def get(): - return await self.assaylayer.get_assay_by_id(-1, check_project_id=False) + return await self.assaylayer.get_assay_by_id(-1) self.assertRaises(NotFoundError, get) @@ -84,9 +84,7 @@ async def test_upsert_assay(self): ) ) - assay = await self.assaylayer.get_assay_by_id( - assay_id=upserted_assay.id, check_project_id=False - ) + assay = await self.assaylayer.get_assay_by_id(assay_id=upserted_assay.id) self.assertEqual(upserted_assay.id, assay.id) self.assertEqual(self.sample_id_raw, int(assay.sample_id)) @@ -475,9 +473,7 @@ async def test_update(self): ) ) - update_assay = await self.assaylayer.get_assay_by_id( - assay_id=assay.id, check_project_id=False - ) + update_assay = await self.assaylayer.get_assay_by_id(assay_id=assay.id) self.assertEqual(assay.id, update_assay.id) self.assertEqual(self.sample_id_raw, int(update_assay.sample_id)) @@ -506,8 +502,7 @@ async def test_update_type(self): # cycle through all statuses, and check that works await self.assaylayer.upsert_assay( - AssayUpsertInternal(id=assay.id, type='metabolomics'), - check_project_id=False, + AssayUpsertInternal(id=assay.id, type='metabolomics') ) row_to_check = await self.connection.connection.fetch_one( 'SELECT type FROM assay WHERE id = :id', diff --git a/test/test_pedigree.py b/test/test_pedigree.py index f263ffd32..328651e95 100644 --- a/test/test_pedigree.py +++ 
b/test/test_pedigree.py @@ -24,7 +24,7 @@ async def test_import_get_pedigree(self): ) pedigree_dicts = await fl.get_pedigree( - project=self.connection.project, + project=self.connection.project_id, replace_with_participant_external_ids=True, replace_with_family_external_ids=True, ) @@ -60,7 +60,7 @@ async def test_pedigree_without_family(self): ) rows = await fl.get_pedigree( - project=self.connection.project, + project=self.connection.project_id, include_participants_not_in_families=True, replace_with_participant_external_ids=True, ) diff --git a/test/test_project_groups.py b/test/test_project_groups.py index 0f289130a..e7db46e9e 100644 --- a/test/test_project_groups.py +++ b/test/test_project_groups.py @@ -6,7 +6,13 @@ GROUP_NAME_PROJECT_CREATORS, ProjectPermissionsTable, ) -from db.python.utils import Forbidden, NotFoundError +from db.python.utils import Forbidden +from models.models.project import ( + FullWriteAccessRoles, + ProjectMemberRole, + ProjectMemberUpdate, + ReadAccessRoles, +) class TestGroupAccess(DbIsolatedTest): @@ -20,9 +26,9 @@ async def setUp(self): super().setUp() # specifically required to test permissions - self.pttable = ProjectPermissionsTable(self.connection, False) + self.pttable = ProjectPermissionsTable(self.connection) - async def _add_group_member_direct(self, group_name): + async def _add_group_member_direct(self, group_name: str): """ Helper method to directly add members to group with name """ @@ -42,87 +48,6 @@ async def _add_group_member_direct(self, group_name): }, ) - @run_as_sync - async def test_group_set_members_failed_no_permission(self): - """ - Test that a user without permission cannot set members - """ - with self.assertRaises(Forbidden): - await self.pttable.set_group_members( - 'another-test-project', ['user1'], self.author - ) - - @run_as_sync - async def test_group_set_members_failed_not_exists(self): - """ - Test that a user with permission, cannot set members - for a group that doesn't exist - """ - await 
self._add_group_member_direct(GROUP_NAME_MEMBERS_ADMIN) - with self.assertRaises(NotFoundError): - await self.pttable.set_group_members( - 'another-test-project', ['user1'], self.author - ) - - @run_as_sync - async def test_group_set_members_succeeded(self): - """ - Test that a user with permission, can set members for a group that exists - """ - await self._add_group_member_direct(GROUP_NAME_MEMBERS_ADMIN) - - g = str(uuid.uuid4()) - await self.pttable.gtable.create_group(g, await self.audit_log_id()) - - self.assertFalse( - await self.pttable.gtable.check_if_member_in_group_name(g, 'user1') - ) - self.assertFalse( - await self.pttable.gtable.check_if_member_in_group_name(g, 'user2') - ) - - await self.pttable.set_group_members( - group_name=g, members=['user1', 'user2'], author=self.author - ) - - self.assertTrue( - await self.pttable.gtable.check_if_member_in_group_name(g, 'user1') - ) - self.assertTrue( - await self.pttable.gtable.check_if_member_in_group_name(g, 'user2') - ) - - @run_as_sync - async def test_check_which_groups_member_is_missing(self): - """Test the check_which_groups_member_has function""" - await self._add_group_member_direct(GROUP_NAME_MEMBERS_ADMIN) - - group = str(uuid.uuid4()) - gid = await self.pttable.gtable.create_group(group, await self.audit_log_id()) - present_gids = await self.pttable.gtable.check_which_groups_member_has( - {gid}, self.author - ) - missing_gids = {gid} - present_gids - self.assertEqual(1, len(missing_gids)) - self.assertEqual(gid, missing_gids.pop()) - - @run_as_sync - async def test_check_which_groups_member_is_missing_none(self): - """Test the check_which_groups_member_has function""" - await self._add_group_member_direct(GROUP_NAME_MEMBERS_ADMIN) - - group = str(uuid.uuid4()) - gid = await self.pttable.gtable.create_group(group, await self.audit_log_id()) - await self.pttable.gtable.set_group_members( - gid, [self.author], audit_log_id=await self.audit_log_id() - ) - present_gids = await 
self.pttable.gtable.check_which_groups_member_has( - group_ids={gid}, member=self.author - ) - missing_gids = {gid} - present_gids - - self.assertEqual(0, len(missing_gids)) - @run_as_sync async def test_project_creators_failed(self): """ @@ -145,11 +70,13 @@ async def test_project_create_succeed(self): project_id = await self.pttable.create_project(g, g, self.author) # pylint: disable=protected-access - project = await self.pttable._get_project_by_id(project_id) + project_id_map, _ = await self.pttable.get_projects_accessible_by_user( + user=self.author + ) - # test that the group names make sense - self.assertIsNotNone(project.read_group_id) - self.assertIsNotNone(project.write_group_id) + project = project_id_map.get(project_id) + assert project + self.assertEqual(project.name, g) class TestProjectAccess(DbIsolatedTest): @@ -161,9 +88,12 @@ async def setUp(self): super().setUp() # specifically required to test permissions - self.pttable = ProjectPermissionsTable(self.connection, False) + self.pttable = ProjectPermissionsTable(self.connection) - async def _add_group_member_direct(self, group_name): + async def _add_group_member_direct( + self, + group_name: str, + ): """ Helper method to directly add members to group with name """ @@ -193,13 +123,13 @@ async def test_no_project_access(self): project_id = await self.pttable.create_project(g, g, self.author) with self.assertRaises(Forbidden): - await self.pttable.get_and_check_access_to_project_for_id( - user=self.author, project_id=project_id, readonly=True + self.connection.check_access_to_projects_for_ids( + project_ids=[project_id], allowed_roles=ReadAccessRoles ) with self.assertRaises(Forbidden): - await self.pttable.get_and_check_access_to_project_for_name( - user=self.author, project_name=g, readonly=True + self.connection.get_and_check_access_to_projects_for_names( + project_names=[g], allowed_roles=ReadAccessRoles ) @run_as_sync @@ -213,21 +143,64 @@ async def test_project_access_success(self): g = 
str(uuid.uuid4()) pid = await self.pttable.create_project(g, g, self.author) - await self.pttable.set_group_members( - group_name=self.pttable.get_project_group_name(g, readonly=True), - members=[self.author], - author=self.author, + + project_id_map, _ = await self.pttable.get_projects_accessible_by_user( + user=self.author + ) + project = project_id_map.get(pid) + assert project + await self.pttable.set_project_members( + project=project, + members=[ProjectMemberUpdate(member=self.author, roles=['reader'])], ) - project_for_id = await self.pttable.get_and_check_access_to_project_for_id( - user=self.author, project_id=pid, readonly=True + project_for_id = self.connection.get_and_check_access_to_projects_for_ids( + project_ids=[pid], allowed_roles=ReadAccessRoles ) - self.assertEqual(pid, project_for_id.id) + user_project_for_id = next(p for p in project_for_id) + self.assertEqual(pid, user_project_for_id.id) - project_for_name = await self.pttable.get_and_check_access_to_project_for_name( - user=self.author, project_name=g, readonly=True + project_for_name = self.connection.get_and_check_access_to_projects_for_names( + project_names=[g], allowed_roles=ReadAccessRoles ) - self.assertEqual(pid, project_for_name.id) + user_project_for_name = next(p for p in project_for_name) + self.assertEqual(g, user_project_for_name.name) + + @run_as_sync + async def test_project_access_insufficient(self): + """ + Test that a user with access to a project will be disallowed if their access is + not sufficient + """ + await self._add_group_member_direct(GROUP_NAME_PROJECT_CREATORS) + await self._add_group_member_direct(GROUP_NAME_MEMBERS_ADMIN) + + g = str(uuid.uuid4()) + + pid = await self.pttable.create_project(g, g, self.author) + + project_id_map, _ = await self.pttable.get_projects_accessible_by_user( + user=self.author + ) + project = project_id_map.get(pid) + assert project + # Give the user read access to the project + await self.pttable.set_project_members( + 
project=project, + members=[ProjectMemberUpdate(member=self.author, roles=['reader'])], + ) + + # But require Write access + + with self.assertRaises(Forbidden): + self.connection.check_access_to_projects_for_ids( + project_ids=[project.id], allowed_roles=FullWriteAccessRoles + ) + + with self.assertRaises(Forbidden): + self.connection.get_and_check_access_to_projects_for_names( + project_names=[g], allowed_roles=FullWriteAccessRoles + ) @run_as_sync async def test_get_my_projects(self): @@ -241,15 +214,27 @@ async def test_get_my_projects(self): pid = await self.pttable.create_project(g, g, self.author) - await self.pttable.set_group_members( - group_name=self.pttable.get_project_group_name(g, readonly=True), - members=[self.author], - author=self.author, + project_id_map, _ = await self.pttable.get_projects_accessible_by_user( + user=self.author + ) + project = project_id_map.get(pid) + assert project + # Give the user read access to the project + await self.pttable.set_project_members( + project=project, + members=[ProjectMemberUpdate(member=self.author, roles=['contributor'])], + ) + + project_id_map, project_name_map = ( + await self.pttable.get_projects_accessible_by_user(user=self.author) ) - projects = await self.pttable.get_projects_accessible_by_user( - author=self.author + # Get projects with at least a read access role + my_projects = self.connection.projects_with_role( + {ProjectMemberRole.contributor} ) + print(my_projects) - self.assertEqual(1, len(projects)) - self.assertEqual(pid, projects[0].id) + self.assertEqual(len(project_id_map), len(project_name_map)) + self.assertEqual(len(my_projects), 1) + self.assertEqual(pid, my_projects[0].id) diff --git a/test/test_sample.py b/test/test_sample.py index b967fdd3e..e50020429 100644 --- a/test/test_sample.py +++ b/test/test_sample.py @@ -51,7 +51,7 @@ async def test_get_sample(self): ) ) - sample = await self.slayer.get_by_id(s.id, check_project_id=False) + sample = await self.slayer.get_by_id(s.id) 
self.assertEqual('blood', sample.type) self.assertDictEqual(meta_dict, sample.meta) @@ -107,7 +107,7 @@ async def test_update_sample(self): SampleUpsertInternal(id=s.id, external_ids=new_external_id_dict) ) - sample = await self.slayer.get_by_id(s.id, check_project_id=False) + sample = await self.slayer.get_by_id(s.id) self.assertDictEqual(new_external_id_dict, sample.external_ids) diff --git a/test/test_search.py b/test/test_search.py index 6bf90d00f..ad3d3c07f 100644 --- a/test/test_search.py +++ b/test/test_search.py @@ -224,7 +224,7 @@ async def test_search_mixed(self): ) all_results = await self.schlay.search( - query='X:', project_ids=[self.connection.project] + query='X:', project_ids=[self.connection.project_id] ) self.assertEqual(3, len(all_results)) diff --git a/test/testbase.py b/test/testbase.py index a61ec4779..7dc08049d 100644 --- a/test/testbase.py +++ b/test/testbase.py @@ -17,7 +17,6 @@ from api.graphql.loaders import get_context # type: ignore from api.graphql.schema import schema # type: ignore -from api.settings import set_all_access from db.python.connect import ( TABLES_ORDERED_BY_FK_DEPS, Connection, @@ -25,7 +24,7 @@ SMConnections, ) from db.python.tables.project import ProjectPermissionsTable -from models.models.project import ProjectId +from models.models.project import Project, ProjectId, ProjectMemberUpdate # use this to determine where the db directory is relatively, # as pycharm runs in "test/" folder, and GH runs them in git root @@ -81,6 +80,8 @@ class DbTest(unittest.TestCase): author: str project_id: ProjectId project_name: str + project_id_map: dict[ProjectId, Project] + project_name_map: dict[str, Project] @classmethod def setUpClass(cls) -> None: @@ -94,9 +95,10 @@ async def setup(): Then you can destroy the database within tearDownClass as all tests have been completed. 
""" + # Set environment to test + os.environ['SM_ENVIRONMENT'] = 'test' logger = logging.getLogger() try: - set_all_access(True) db = MySqlContainer('mariadb:11.2.2', password='test') port_to_expose = find_free_port() # override the default port to map the container to @@ -141,23 +143,56 @@ async def setup(): formed_connection = Connection( connection=sm_db, author=cls.author, - readonly=False, + project_id_map={}, + project_name_map={}, on_behalf_of=None, ar_guid=None, project=None, ) + + # Add the test user to the admin groups + await formed_connection.connection.execute( + f""" + INSERT INTO group_member(group_id, member) + SELECT id, '{cls.author}' + FROM `group` WHERE name IN('project-creators', 'members-admin') + """ + ) + ppt = ProjectPermissionsTable( connection=formed_connection, - allow_full_access=True, ) cls.project_name = 'test' cls.project_id = await ppt.create_project( project_name=cls.project_name, dataset_name=cls.project_name, author=cls.author, - check_permissions=False, ) - formed_connection.project = cls.project_id + + _, initial_project_name_map = await ppt.get_projects_accessible_by_user( + user=cls.author + ) + created_project = initial_project_name_map.get(cls.project_name) + assert created_project + + await ppt.set_project_members( + project=created_project, + members=[ + ProjectMemberUpdate( + member=cls.author, roles=['reader', 'writer'] + ) + ], + ) + + # Get the new project map now that project membership is updated + project_id_map, project_name_map = ( + await ppt.get_projects_accessible_by_user(user=cls.author) + ) + + cls.project_id_map = project_id_map + cls.project_name_map = project_name_map + + # formed_connection.project = cls.project_id except subprocess.CalledProcessError as e: logging.exception(e) @@ -189,9 +224,10 @@ def setUp(self) -> None: # audit_log ID for each test self.connection = Connection( connection=self._connection, - project=self.project_id, + project=self.project_id_map.get(self.project_id), 
author=self.author, - readonly=False, + project_id_map=self.project_id_map, + project_name_map=self.project_name_map, ar_guid=None, on_behalf_of=None, ) @@ -266,7 +302,13 @@ async def tearDown(self) -> None: async def clear_database(self): """Clear the database of all data, except for project + group tables""" - ignore = {'DATABASECHANGELOG', 'DATABASECHANGELOGLOCK', 'project', 'group'} + ignore = { + 'DATABASECHANGELOG', + 'DATABASECHANGELOGLOCK', + 'project', + 'group', + 'project-member', + } for table in TABLES_ORDERED_BY_FK_DEPS: if table in ignore: continue diff --git a/web/package.json b/web/package.json index 88f97a346..8451098af 100644 --- a/web/package.json +++ b/web/package.json @@ -1,6 +1,6 @@ { "name": "metamist", - "version": "7.1.1", + "version": "7.2.0", "private": true, "dependencies": { "@apollo/client": "^3.7.3",