Skip to content

Commit

Permalink
Update routes, layers and table files to work with new permissions
Browse files Browse the repository at this point in the history
  • Loading branch information
dancoates committed Jun 14, 2024
1 parent e859d07 commit e30590f
Show file tree
Hide file tree
Showing 33 changed files with 491 additions and 575 deletions.
80 changes: 34 additions & 46 deletions api/routes/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,16 @@
from api.utils.dates import parse_date_only_string
from api.utils.db import (
Connection,
get_project_read_connection,
get_project_write_connection,
get_project_db_connection,
get_projectless_db_connection,
)
from api.utils.export import ExportType
from db.python.layers.analysis import AnalysisLayer
from db.python.tables.analysis import AnalysisFilter
from db.python.tables.project import ProjectPermissionsTable
from db.python.utils import GenericFilter
from models.enums import AnalysisStatus
from models.models.analysis import (
Analysis,
AnalysisInternal,
ProportionalDateTemporalMethod,
)
from models.models.group import ReadAccessRoles
from models.models.analysis import Analysis, ProportionalDateTemporalMethod
from models.models.project import FullWriteAccessRoles, ReadAccessRoles
from models.utils.sequencing_group_id_format import (
sequencing_group_id_format,
sequencing_group_id_format_list,
Expand Down Expand Up @@ -96,7 +90,8 @@ def to_filter(self, project_id_map: dict[str, int]) -> AnalysisFilter:

@router.put('/{project}/', operation_id='createAnalysis', response_model=int)
async def create_analysis(
analysis: Analysis, connection: Connection = get_project_write_connection
analysis: Analysis,
connection: Connection = get_project_db_connection(FullWriteAccessRoles),
) -> int:
"""Create a new analysis"""

Expand Down Expand Up @@ -136,14 +131,14 @@ async def update_analysis(
)
async def get_all_sample_ids_without_analysis_type(
analysis_type: str,
connection: Connection = get_project_read_connection,
connection: Connection = get_project_db_connection(ReadAccessRoles),
):
"""get_all_sample_ids_without_analysis_type"""
atable = AnalysisLayer(connection)
assert connection.project
assert connection.project_id
sequencing_group_ids = (
await atable.get_all_sequencing_group_ids_without_analysis_type(
connection.project, analysis_type
connection.project_id, analysis_type
)
)
return {
Expand All @@ -156,12 +151,12 @@ async def get_all_sample_ids_without_analysis_type(
operation_id='getIncompleteAnalyses',
)
async def get_incomplete_analyses(
connection: Connection = get_project_read_connection,
connection: Connection = get_project_db_connection(ReadAccessRoles),
):
"""Get analyses with status queued or in-progress"""
atable = AnalysisLayer(connection)
assert connection.project
results = await atable.get_incomplete_analyses(project=connection.project)
assert connection.project_id
results = await atable.get_incomplete_analyses(project=connection.project_id)
return [r.to_external() for r in results]


Expand All @@ -171,13 +166,13 @@ async def get_incomplete_analyses(
)
async def get_latest_complete_analysis_for_type(
analysis_type: str,
connection: Connection = get_project_read_connection,
connection: Connection = get_project_db_connection(ReadAccessRoles),
):
"""Get (SINGLE) latest complete analysis for some analysis type"""
alayer = AnalysisLayer(connection)
assert connection.project
assert connection.project_id
analysis = await alayer.get_latest_complete_analysis_for_type(
project=connection.project, analysis_type=analysis_type
project=connection.project_id, analysis_type=analysis_type
)
return analysis.to_external()

Expand All @@ -189,16 +184,16 @@ async def get_latest_complete_analysis_for_type(
async def get_latest_complete_analysis_for_type_post(
analysis_type: str,
meta: dict[str, Any] = Body(..., embed=True), # type: ignore[assignment]
connection: Connection = get_project_read_connection,
connection: Connection = get_project_db_connection(ReadAccessRoles),
):
"""
Get SINGLE latest complete analysis for some analysis type
(you can specify meta attributes in this route)
"""
alayer = AnalysisLayer(connection)
assert connection.project
assert connection.project_id
analysis = await alayer.get_latest_complete_analysis_for_type(
project=connection.project,
project=connection.project_id,
analysis_type=analysis_type,
meta=meta,
)
Expand Down Expand Up @@ -229,11 +224,8 @@ async def query_analyses(
if not query.projects:
raise ValueError('Must specify "projects"')

pt = ProjectPermissionsTable(connection)
projects = await pt.get_and_check_access_to_projects_for_names(
user=connection.author,
project_names=query.projects,
allowed_roles=ReadAccessRoles,
projects = connection.get_and_check_access_to_projects_for_names(
query.projects, ReadAccessRoles
)
project_name_map = {p.name: p.id for p in projects}
atable = AnalysisLayer(connection)
Expand All @@ -243,27 +235,24 @@ async def query_analyses(

@router.get('/analysis-runner', operation_id='getAnalysisRunnerLog')
async def get_analysis_runner_log(
project_names: list[str] = Query(None), # type: ignore
# author: str = None, # not implemented yet, uncomment when we do
output_dir: str = None,
ar_guid: str = None,
project_names: list[str],
output_dir: str,
ar_guid: str | None = None,
connection: Connection = get_projectless_db_connection,
) -> list[AnalysisInternal]:
) -> list[Analysis]:
"""
Get log for the analysis-runner, useful for checking this history of analysis
"""
atable = AnalysisLayer(connection)
project_ids = None
if project_names:
pt = ProjectPermissionsTable(connection)
projects = await pt.get_and_check_access_to_projects_for_names(
connection.author, project_names, allowed_roles=ReadAccessRoles
)
project_ids = [p.id for p in projects]

projects = connection.get_and_check_access_to_projects_for_names(
project_names, allowed_roles=ReadAccessRoles
)
project_ids = [p.id for p in projects]

results = await atable.get_analysis_runner_log(
project_ids=project_ids,
# author=author,
output_dir=output_dir,
ar_guid=ar_guid,
)
Expand All @@ -278,7 +267,7 @@ async def get_analysis_runner_log(
async def get_sample_reads_map(
export_type: ExportType = ExportType.JSON,
sequencing_types: list[str] = Query(None), # type: ignore
connection: Connection = get_project_read_connection,
connection: Connection = get_project_db_connection(ReadAccessRoles),
):
"""
Get map of ExternalSampleId pathToCram InternalSeqGroupID for seqr
Expand All @@ -293,9 +282,9 @@ async def get_sample_reads_map(
"""

at = AnalysisLayer(connection)
assert connection.project
assert connection.project_id
objs = await at.get_sample_cram_path_map_for_seqr(
project=connection.project, sequencing_types=sequencing_types
project=connection.project_id, sequencing_types=sequencing_types
)

for r in objs:
Expand All @@ -311,7 +300,7 @@ async def get_sample_reads_map(
writer = csv.writer(output, delimiter=export_type.get_delimiter())
writer.writerows(rows)

basefn = f'{connection.project}-seqr-igv-paths-{date.today().isoformat()}'
basefn = f'{connection.project_id}-seqr-igv-paths-{date.today().isoformat()}'

return StreamingResponse(
iter([output.getvalue()]),
Expand Down Expand Up @@ -345,9 +334,8 @@ async def get_proportionate_map(
}
}
"""
pt = ProjectPermissionsTable(connection)
project_list = await pt.get_and_check_access_to_projects_for_names(
connection.author, projects, allowed_roles=ReadAccessRoles
project_list = connection.get_and_check_access_to_projects_for_names(
projects, allowed_roles=ReadAccessRoles
)
project_ids = [p.id for p in project_list]

Expand Down
21 changes: 9 additions & 12 deletions api/routes/analysis_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,12 @@

from fastapi import APIRouter

from api.utils.db import (
Connection,
get_project_read_connection,
get_project_write_connection,
)
from api.utils.db import Connection, get_project_db_connection
from db.python.layers.analysis_runner import AnalysisRunnerLayer
from db.python.tables.analysis_runner import AnalysisRunnerFilter
from db.python.utils import GenericFilter
from models.models.analysis_runner import AnalysisRunner, AnalysisRunnerInternal
from models.models.project import FullWriteAccessRoles, ReadAccessRoles

router = APIRouter(prefix='/analysis-runner', tags=['analysis-runner'])

Expand All @@ -32,13 +29,13 @@ async def create_analysis_runner_log( # pylint: disable=too-many-arguments
output_path: str,
hail_version: str | None = None,
cwd: str | None = None,
connection: Connection = get_project_write_connection,
connection: Connection = get_project_db_connection(FullWriteAccessRoles),
) -> str:
"""Create a new analysis runner log"""

alayer = AnalysisRunnerLayer(connection)

if not connection.project:
if not connection.project_id:
raise ValueError('Project not set')

analysis_id = await alayer.insert_analysis_runner_entry(
Expand All @@ -58,7 +55,7 @@ async def create_analysis_runner_log( # pylint: disable=too-many-arguments
batch_url=batch_url,
submitting_user=submitting_user,
meta=meta,
project=connection.project,
project=connection.project_id,
audit_log_id=None,
output_path=output_path,
)
Expand All @@ -75,13 +72,13 @@ async def get_analysis_runner_logs(
repository: str | None = None,
access_level: str | None = None,
environment: str | None = None,
connection: Connection = get_project_read_connection,
connection: Connection = get_project_db_connection(ReadAccessRoles),
) -> list[AnalysisRunner]:
"""Get analysis runner logs"""

atable = AnalysisRunnerLayer(connection)

if not connection.project:
if not connection.project_id:
raise ValueError('Project not set')

filter_ = AnalysisRunnerFilter(
Expand All @@ -90,9 +87,9 @@ async def get_analysis_runner_logs(
repository=GenericFilter(eq=repository),
access_level=GenericFilter(eq=access_level),
environment=GenericFilter(eq=environment),
project=GenericFilter(eq=connection.project),
project=GenericFilter(eq=connection.project_id),
)

logs = await atable.query(filter_)

return [log.to_external({connection.project: project}) for log in logs]
return [log.to_external({connection.project_id: project}) for log in logs]
13 changes: 6 additions & 7 deletions api/routes/assay.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@

from api.utils.db import (
Connection,
get_project_read_connection,
get_project_db_connection,
get_projectless_db_connection,
)
from db.python.layers.assay import AssayLayer
from db.python.tables.assay import AssayFilter
from db.python.tables.project import ProjectPermissionsTable
from db.python.utils import GenericFilter
from models.base import SMBase
from models.models.assay import AssayUpsert
from models.models.group import ReadAccessRoles
from models.models.project import ReadAccessRoles
from models.utils.sample_id_format import sample_id_transform_to_raw_list

router = APIRouter(prefix='/assay', tags=['assay'])
Expand Down Expand Up @@ -57,7 +56,8 @@ async def get_assay_by_id(
'/{project}/external_id/{external_id}/details', operation_id='getAssayByExternalId'
)
async def get_assay_by_external_id(
external_id: str, connection=get_project_read_connection
external_id: str,
connection: Connection = get_project_db_connection(ReadAccessRoles),
):
"""Get an assay by ONE of its external identifiers"""
assay_layer = AssayLayer(connection)
Expand All @@ -84,12 +84,11 @@ async def get_assays_by_criteria(
):
"""Get assays by criteria"""
assay_layer = AssayLayer(connection)
pt = ProjectPermissionsTable(connection)

pids: list[int] | None = None
if criteria.projects:
project_list = await pt.get_and_check_access_to_projects_for_names(
connection.author, criteria.projects, allowed_roles=ReadAccessRoles
project_list = connection.get_and_check_access_to_projects_for_names(
criteria.projects, allowed_roles=ReadAccessRoles
)
pids = [p.id for p in project_list]

Expand Down
Loading

0 comments on commit e30590f

Please sign in to comment.