Skip to content

Commit

Permalink
Inbuilt groups for permissions (#568)
Browse files Browse the repository at this point in the history
* First pass at an internally managed group members

* Project permissions table cleanup + add tests

* Linting + tests

* Remove bool,ord as bool is a tinyint

* PR clean-up

* Apply treatement to incoming change

* Fix incorrectly applied change

* Fix nullable reference

* Runtime fixes

* Revert "Fix nullable reference"

This reverts commit d72d605.

* Fix nullable reference II

---------

Co-authored-by: Michael Franklin <[email protected]>
  • Loading branch information
illusional and illusional authored Oct 21, 2023
1 parent 040759e commit 8dcd99e
Show file tree
Hide file tree
Showing 18 changed files with 1,035 additions and 619 deletions.
1 change: 0 additions & 1 deletion .github/workflows/deploy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ jobs:
CLOUDSDK_CORE_DISABLE_PROMPTS: 1
# used for generating API
SM_DOCKER: australia-southeast1-docker.pkg.dev/sample-metadata/images/server:${{ github.sha }}
SM_API_DOCKER: australia-southeast1-docker.pkg.dev/cpg-common/images/sm-api
defaults:
run:
shell: bash -eo pipefail -l {0}
Expand Down
5 changes: 2 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,14 @@ repos:
rev: v1.5.1
hooks:
- id: mypy
args:
[
args: [
--pretty,
--show-error-codes,
--no-strict-optional,
--ignore-missing-imports,
--install-types,
--non-interactive,
--show-error-context,
# --show-error-context,
--check-untyped-defs,
--explicit-package-bases,
--disable-error-code,
Expand Down
4 changes: 3 additions & 1 deletion api/graphql/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,9 @@ async def load_projects_for_ids(project_ids: list[int], connection) -> list[Proj
Get projects by IDs
"""
pttable = ProjectPermissionsTable(connection.connection)
projects = await pttable.get_projects_by_ids(project_ids)
projects = await pttable.get_and_check_access_to_projects_for_ids(
user=connection.user, project_ids=project_ids, readonly=True
)
p_by_id = {p.id: p for p in projects}
return [p_by_id.get(p) for p in project_ids]

Expand Down
31 changes: 15 additions & 16 deletions api/graphql/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,6 @@ class GraphQLProject:
name: str
dataset: str
meta: strawberry.scalars.JSON
read_group_name: str | None = None
write_group_name: str | None = None

@staticmethod
def from_internal(internal: Project) -> 'GraphQLProject':
Expand All @@ -82,8 +80,6 @@ def from_internal(internal: Project) -> 'GraphQLProject':
name=internal.name,
dataset=internal.dataset,
meta=internal.meta,
read_group_name=internal.read_group_name,
write_group_name=internal.write_group_name,
)

@strawberry.field()
Expand Down Expand Up @@ -492,9 +488,11 @@ async def analyses(
if project:
ptable = ProjectPermissionsTable(connection.connection)
project_ids = project.all_values()
project_id_map = await ptable.get_project_id_map_for_names(
author=connection.author, project_names=project_ids, readonly=True
projects = await ptable.get_and_check_access_to_projects_for_names(
user=connection.author, project_names=project_ids, readonly=True
)
project_id_map = {p.name: p.id for p in projects}

analyses = await loader.load(
{
'id': root.internal_id,
Expand Down Expand Up @@ -562,11 +560,10 @@ def enum(self, info: Info) -> GraphQLEnum:
async def project(self, info: Info, name: str) -> GraphQLProject:
connection = info.context['connection']
ptable = ProjectPermissionsTable(connection.connection)
project_id = await ptable.get_project_id_from_name_and_user(
project = await ptable.get_and_check_access_to_project_for_name(
user=connection.author, project_name=name, readonly=True
)
presponse = await ptable.get_projects_by_ids([project_id])
return GraphQLProject.from_internal(presponse[0])
return GraphQLProject.from_internal(project)

@strawberry.field
async def sample(
Expand All @@ -590,11 +587,13 @@ async def sample(
if external_id and not project:
raise ValueError('Must provide project when using external_id filter')

project_name_map: dict[str, int] = {}
if project:
project_ids = project.all_values()
project_id_map = await ptable.get_project_id_map_for_names(
author=connection.author, project_names=project_ids, readonly=True
projects = await ptable.get_and_check_access_to_projects_for_ids(
user=connection.author, project_ids=project_ids, readonly=True
)
project_name_map = {p.name: p.id for p in projects}

filter_ = SampleFilter(
id=id.to_internal_filter(sample_id_transform_to_raw) if id else None,
Expand All @@ -604,7 +603,7 @@ async def sample(
participant_id=participant_id.to_internal_filter()
if participant_id
else None,
project=project.to_internal_filter(lambda pname: project_id_map[pname])
project=project.to_internal_filter(lambda pname: project_name_map[pname])
if project
else None,
active=active.to_internal_filter() if active else GenericFilter(eq=True),
Expand Down Expand Up @@ -635,9 +634,10 @@ async def sequencing_groups(
project_id_map = {}
if project:
project_ids = project.all_values()
project_id_map = await ptable.get_project_id_map_for_names(
author=connection.author, project_names=project_ids, readonly=True
projects = await ptable.get_and_check_access_to_projects_for_ids(
user=connection.author, project_ids=project_ids, readonly=True
)
project_id_map = {p.name: p.id for p in projects}

filter_ = SequencingGroupFilter(
project=project.to_internal_filter(lambda val: project_id_map[val])
Expand Down Expand Up @@ -681,10 +681,9 @@ async def family(self, info: Info, family_id: int) -> GraphQLFamily:
async def my_projects(self, info: Info) -> list[GraphQLProject]:
connection = info.context['connection']
ptable = ProjectPermissionsTable(connection.connection)
project_map = await ptable.get_projects_accessible_by_user(
projects = await ptable.get_projects_accessible_by_user(
connection.author, readonly=True
)
projects = await ptable.get_projects_by_ids(list(project_map.keys()))
return [GraphQLProject.from_internal(p) for p in projects]


Expand Down
5 changes: 3 additions & 2 deletions api/routes/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,10 @@ async def query_analyses(
raise ValueError('Must specify "projects"')

pt = ProjectPermissionsTable(connection=connection.connection)
project_name_map = await pt.get_project_id_map_for_names(
author=connection.author, project_names=query.projects, readonly=True
projects = await pt.get_and_check_access_to_projects_for_names(
user=connection.author, project_names=query.projects, readonly=True
)
project_name_map = {p.name: p.id for p in projects}
atable = AnalysisLayer(connection)
analyses = await atable.query(query.to_filter(project_name_map))
return [a.to_external() for a in analyses]
Expand Down
56 changes: 36 additions & 20 deletions api/routes/project.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List

from fastapi import APIRouter

Expand All @@ -13,48 +13,43 @@
async def get_all_projects(connection=get_projectless_db_connection):
"""Get list of projects"""
ptable = ProjectPermissionsTable(connection.connection)
return await ptable.get_project_rows(
author=connection.author, check_permissions=False
)
return await ptable.get_all_projects(author=connection.author)


@router.get('/', operation_id='getMyProjects', response_model=List[str])
async def get_my_projects(connection=get_projectless_db_connection):
"""Get projects I have access to"""
ptable = ProjectPermissionsTable(connection.connection)
pmap = await ptable.get_projects_accessible_by_user(
projects = await ptable.get_projects_accessible_by_user(
author=connection.author, readonly=True
)
return list(pmap.values())
return [p.name for p in projects]


@router.put('/', operation_id='createProject')
async def create_project(
name: str,
dataset: str,
read_group_name: Optional[str] = None,
write_group_name: Optional[str] = None,
create_test_project: bool = True,
create_test_project: bool = False,
connection: Connection = get_projectless_db_connection,
) -> int:
"""
Create a new project
"""
if not read_group_name:
read_group_name = f'{dataset}-sample-metadata-main-read'
if not write_group_name:
write_group_name = f'{dataset}-sample-metadata-main-write'

ptable = ProjectPermissionsTable(connection.connection)
pid = await ptable.create_project(
project_name=name,
dataset_name=dataset,
read_group_name=read_group_name,
write_group_name=write_group_name,
create_test_project=create_test_project,
author=connection.author,
)

if create_test_project:
await ptable.create_project(
project_name=name + '-test',
dataset_name=dataset,
author=connection.author,
)

return pid


Expand Down Expand Up @@ -90,11 +85,32 @@ async def delete_project_data(
Requires READ access + project-creator permissions
"""
ptable = ProjectPermissionsTable(connection.connection)
pid = await ptable.get_project_id_from_name_and_user(
connection.author, project, readonly=False
p_obj = await ptable.get_and_check_access_to_project_for_name(
user=connection.author, project_name=project, readonly=False
)
success = await ptable.delete_project_data(
project_id=pid, delete_project=delete_project, author=connection.author
project_id=p_obj.id, delete_project=delete_project, author=connection.author
)

return {'success': success}


@router.patch('/{project}/members', operation_id='updateProjectMembers')
async def update_project_members(
project: str,
members: list[str],
readonly: bool,
connection: Connection = get_projectless_db_connection,
):
"""
Update project members for specific read / write group.
Not that this is protected by access to a specific access group
"""
ptable = ProjectPermissionsTable(connection.connection)
await ptable.set_group_members(
group_name=ptable.get_project_group_name(project, readonly=readonly),
members=members,
author=connection.author,
)

return {'success': True}
14 changes: 8 additions & 6 deletions api/routes/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@
from pydantic import BaseModel

from api.utils.db import (
Connection,
get_project_readonly_connection,
get_projectless_db_connection,
get_project_write_connection,
Connection,
get_projectless_db_connection,
)
from db.python.layers.search import SearchLayer
from db.python.layers.seqr import SeqrLayer
from db.python.layers.web import WebLayer, SearchItem
from db.python.layers.web import SearchItem, WebLayer
from db.python.tables.project import ProjectPermissionsTable
from models.models.search import SearchResponse
from models.models.web import ProjectSummary, PagingLinks
from models.models.web import PagingLinks, ProjectSummary


class SearchResponseModel(BaseModel):
Expand Down Expand Up @@ -81,7 +81,7 @@ async def search_by_keyword(keyword: str, connection=get_projectless_db_connecti
projects = await pt.get_projects_accessible_by_user(
connection.author, readonly=True
)
project_ids = list(projects.keys())
project_ids = [p.id for p in projects]
responses = await SearchLayer(connection).search(keyword, project_ids=project_ids)

for res in responses:
Expand All @@ -90,7 +90,9 @@ async def search_by_keyword(keyword: str, connection=get_projectless_db_connecti
return SearchResponseModel(responses=responses)


@router.post('/{project}/{sequencing_type}/sync-dataset', operation_id='syncSeqrProject')
@router.post(
'/{project}/{sequencing_type}/sync-dataset', operation_id='syncSeqrProject'
)
async def sync_seqr_project(
sequencing_type: str,
sync_families: bool = True,
Expand Down
Loading

0 comments on commit 8dcd99e

Please sign in to comment.