Skip to content

Commit

Permalink
Update project based db connections
Browse files Browse the repository at this point in the history
  • Loading branch information
dancoates committed Mar 22, 2024
1 parent e9c5509 commit 8135d4c
Show file tree
Hide file tree
Showing 12 changed files with 103 additions and 59 deletions.
12 changes: 6 additions & 6 deletions api/routes/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from api.utils.dates import parse_date_only_string
from api.utils.db import (
Connection,
get_project_readonly_connection,
get_project_read_connection,
get_project_write_connection,
get_projectless_db_connection,
)
Expand Down Expand Up @@ -141,7 +141,7 @@ async def update_analysis(
)
async def get_all_sample_ids_without_analysis_type(
analysis_type: str,
connection: Connection = get_project_readonly_connection,
connection: Connection = get_project_read_connection,
):
"""get_all_sample_ids_without_analysis_type"""
atable = AnalysisLayer(connection)
Expand All @@ -161,7 +161,7 @@ async def get_all_sample_ids_without_analysis_type(
operation_id='getIncompleteAnalyses',
)
async def get_incomplete_analyses(
connection: Connection = get_project_readonly_connection,
connection: Connection = get_project_read_connection,
):
"""Get analyses with status queued or in-progress"""
atable = AnalysisLayer(connection)
Expand All @@ -176,7 +176,7 @@ async def get_incomplete_analyses(
)
async def get_latest_complete_analysis_for_type(
analysis_type: str,
connection: Connection = get_project_readonly_connection,
connection: Connection = get_project_read_connection,
):
"""Get (SINGLE) latest complete analysis for some analysis type"""
alayer = AnalysisLayer(connection)
Expand All @@ -194,7 +194,7 @@ async def get_latest_complete_analysis_for_type(
async def get_latest_complete_analysis_for_type_post(
analysis_type: str,
meta: dict[str, Any] = Body(..., embed=True), # type: ignore[assignment]
connection: Connection = get_project_readonly_connection,
connection: Connection = get_project_read_connection,
):
"""
Get SINGLE latest complete analysis for some analysis type
Expand Down Expand Up @@ -280,7 +280,7 @@ async def get_analysis_runner_log(
async def get_sample_reads_map(
export_type: ExportType = ExportType.JSON,
sequencing_types: list[str] = Query(None), # type: ignore
connection: Connection = get_project_readonly_connection,
connection: Connection = get_project_read_connection,
):
"""
Get map of ExternalSampleId pathToCram InternalSeqGroupID for seqr
Expand Down
16 changes: 8 additions & 8 deletions api/routes/assay.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from fastapi import APIRouter

from api.utils import get_project_readonly_connection
from api.utils import get_project_read_connection
from api.utils.db import Connection, get_projectless_db_connection
from db.python.layers.assay import AssayLayer
from db.python.tables.assay import AssayFilter
Expand Down Expand Up @@ -52,7 +52,7 @@ async def get_assay_by_id(
'/{project}/external_id/{external_id}/details', operation_id='getAssayByExternalId'
)
async def get_assay_by_external_id(
external_id: str, connection=get_project_readonly_connection
external_id: str, connection=get_project_read_connection
):
"""Get an assay by ONE of its external identifiers"""
assay_layer = AssayLayer(connection)
Expand Down Expand Up @@ -86,13 +86,13 @@ async def get_assays_by_criteria(
unwrapped_sample_ids = sample_id_transform_to_raw_list(sample_ids)

filter_ = AssayFilter(
sample_id=GenericFilter(in_=unwrapped_sample_ids)
if unwrapped_sample_ids
else None,
sample_id=(
GenericFilter(in_=unwrapped_sample_ids) if unwrapped_sample_ids else None
),
id=GenericFilter(in_=assay_ids) if assay_ids else None,
external_id=GenericFilter(in_=external_assay_ids)
if external_assay_ids
else None,
external_id=(
GenericFilter(in_=external_assay_ids) if external_assay_ids else None
),
meta=assay_meta,
sample_meta=sample_meta,
project=GenericFilter(in_=pids) if pids else None,
Expand Down
6 changes: 3 additions & 3 deletions api/routes/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from api.utils import get_projectless_db_connection
from api.utils.db import (
Connection,
get_project_readonly_connection,
get_project_read_connection,
get_project_write_connection,
)
from api.utils.export import ExportType
Expand Down Expand Up @@ -73,7 +73,7 @@ async def get_pedigree(
replace_with_family_external_ids: bool = True,
include_header: bool = True,
empty_participant_value: Optional[str] = None,
connection: Connection = get_project_readonly_connection,
connection: Connection = get_project_read_connection,
include_participants_not_in_families: bool = False,
):
"""
Expand Down Expand Up @@ -140,7 +140,7 @@ async def get_pedigree(
async def get_families(
participant_ids: Optional[List[int]] = Query(None),
sample_ids: Optional[List[str]] = Query(None),
connection: Connection = get_project_readonly_connection,
connection: Connection = get_project_read_connection,
) -> List[Family]:
"""Get families for some project"""
family_layer = FamilyLayer(connection)
Expand Down
14 changes: 8 additions & 6 deletions api/routes/participant.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from api.utils import get_projectless_db_connection
from api.utils.db import (
Connection,
get_project_readonly_connection,
get_project_read_connection,
get_project_write_connection,
)
from api.utils.export import ExportType
Expand Down Expand Up @@ -46,7 +46,7 @@ async def get_individual_metadata_template_for_seqr(
external_participant_ids: list[str] | None = Query(default=None), # type: ignore[assignment]
# pylint: disable=invalid-name
replace_with_participant_external_ids: bool = True,
connection: Connection = get_project_readonly_connection,
connection: Connection = get_project_read_connection,
):
"""Get individual metadata template for SEQR as a CSV"""
participant_layer = ParticipantLayer(connection)
Expand Down Expand Up @@ -89,7 +89,7 @@ async def get_individual_metadata_template_for_seqr(
async def get_id_map_by_external_ids(
external_participant_ids: list[str],
allow_missing: bool = False,
connection: Connection = get_project_readonly_connection,
connection: Connection = get_project_read_connection,
):
"""Get ID map of participants, by external_id"""
player = ParticipantLayer(connection)
Expand Down Expand Up @@ -120,7 +120,7 @@ async def get_external_participant_id_to_sequencing_group_id(
sequencing_type: str = None,
export_type: ExportType = ExportType.JSON,
flip_columns: bool = False,
connection: Connection = get_project_readonly_connection,
connection: Connection = get_project_read_connection,
):
"""
Get csv / tsv export of external_participant_id to sequencing_group_id
Expand Down Expand Up @@ -153,7 +153,9 @@ async def get_external_participant_id_to_sequencing_group_id(
writer.writerows(rows)

ext = export_type.get_extension()
filename = f'{project}-participant-to-sequencing-group-map-{date.today().isoformat()}{ext}'
filename = (
f'{project}-participant-to-sequencing-group-map-{date.today().isoformat()}{ext}'
)
if sequencing_type:
filename = f'{project}-{sequencing_type}-participant-to-sequencing-group-map-{date.today().isoformat()}{ext}'
return StreamingResponse(
Expand Down Expand Up @@ -205,7 +207,7 @@ async def upsert_participants(
async def get_participants(
external_participant_ids: list[str] = None,
internal_participant_ids: list[int] = None,
connection: Connection = get_project_readonly_connection,
connection: Connection = get_project_read_connection,
):
"""Get participants, default ALL participants in project"""
player = ParticipantLayer(connection)
Expand Down
8 changes: 4 additions & 4 deletions api/routes/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from api.utils.db import (
Connection,
get_project_readonly_connection,
get_project_read_connection,
get_project_write_connection,
get_projectless_db_connection,
)
Expand Down Expand Up @@ -61,7 +61,7 @@ async def upsert_samples(
async def get_sample_id_map_by_external(
external_ids: list[str],
allow_missing: bool = False,
connection: Connection = get_project_readonly_connection,
connection: Connection = get_project_read_connection,
):
"""Get map of sample IDs, { [externalId]: internal_sample_id }"""
st = SampleLayer(connection)
Expand Down Expand Up @@ -90,7 +90,7 @@ async def get_sample_id_map_by_internal(
'/{project}/id-map/internal/all', operation_id='getAllSampleIdMapByInternal'
)
async def get_all_sample_id_map_by_internal(
connection: Connection = get_project_readonly_connection,
connection: Connection = get_project_read_connection,
):
"""Get map of ALL sample IDs, { [internal_id]: external_sample_id }"""
st = SampleLayer(connection)
Expand All @@ -106,7 +106,7 @@ async def get_all_sample_id_map_by_internal(
operation_id='getSampleByExternalId',
)
async def get_sample_by_external_id(
external_id: str, connection: Connection = get_project_readonly_connection
external_id: str, connection: Connection = get_project_read_connection
):
"""Get sample by external ID"""
st = SampleLayer(connection)
Expand Down
4 changes: 2 additions & 2 deletions api/routes/sequencing_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from api.utils.db import (
Connection,
get_project_readonly_connection,
get_project_read_connection,
get_project_write_connection,
get_projectless_db_connection,
)
Expand Down Expand Up @@ -42,7 +42,7 @@ async def get_sequencing_group(

@router.get('/project/{project}', operation_id='getAllSequencingGroupIdsBySampleByType')
async def get_all_sequencing_group_ids_by_sample_by_type(
connection: Connection = get_project_readonly_connection,
connection: Connection = get_project_read_connection,
) -> dict[str, dict[str, list[str]]]:
"""Creates a new sample, and returns the internal sample ID"""
st = SequencingGroupLayer(connection)
Expand Down
15 changes: 8 additions & 7 deletions api/routes/web.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""
Web routes
"""

from typing import Optional

from fastapi import APIRouter, Request
from pydantic import BaseModel

from api.utils.db import (
Connection,
get_project_readonly_connection,
get_project_read_connection,
get_project_write_connection,
get_projectless_db_connection,
)
Expand Down Expand Up @@ -39,7 +40,7 @@ async def get_project_summary(
grid_filter: list[SearchItem],
limit: int = 20,
token: Optional[int] = 0,
connection: Connection = get_project_readonly_connection,
connection: Connection = get_project_read_connection,
) -> ProjectSummary:
"""Creates a new sample, and returns the internal sample ID"""
st = WebLayer(connection)
Expand All @@ -56,11 +57,11 @@ async def get_project_summary(
new_token = max(int(sample.id) for p in participants for sample in p.samples)

links = PagingLinks(
next=str(request.base_url)
+ request.url.path.lstrip('/')
+ f'?token={new_token}'
if new_token
else None,
next=(
str(request.base_url) + request.url.path.lstrip('/') + f'?token={new_token}'
if new_token
else None
),
self=str(request.url),
token=str(new_token) if new_token else None,
)
Expand Down
39 changes: 35 additions & 4 deletions api/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from db.python.connect import Connection, SMConnections
from db.python.gcp_connect import BqConnection, PubSubConnection
from db.python.tables.project import ProjectPermissionsTable
from models.models.group import GroupProjectRole

EXPECTED_AUDIENCE = getenv('SM_OAUTHAUDIENCE')

Expand Down Expand Up @@ -78,23 +79,50 @@ async def dependable_get_write_project_connection(
return await ProjectPermissionsTable.get_project_connection(
project_name=project,
author=author,
readonly=False,
allowed_roles={GroupProjectRole.write},
ar_guid=ar_guid,
on_behalf_of=on_behalf_of,
meta=meta,
)


async def dependable_get_readonly_project_connection(
async def dependable_get_read_project_connection(
project: str,
request: Request,
author: str = Depends(authenticate),
ar_guid: str = Depends(get_ar_guid),
) -> Connection:
"""FastAPI handler for getting connection WITH project"""
meta = {"path": request.url.path}
if request.client:
meta["ip"] = request.client.host
return await ProjectPermissionsTable.get_project_connection(
project_name=project,
author=author,
allowed_roles={
GroupProjectRole.read,
GroupProjectRole.write,
GroupProjectRole.contribute,
},
on_behalf_of=None,
ar_guid=ar_guid,
)


async def dependable_get_contribute_project_connection(
project: str,
request: Request,
author: str = Depends(authenticate),
ar_guid: str = Depends(get_ar_guid),
) -> Connection:
"""FastAPI handler for getting connection WITH project"""
meta = {"path": request.url.path}
if request.client:
meta["ip"] = request.client.host
return await ProjectPermissionsTable.get_project_connection(
project_name=project,
author=author,
readonly=True,
allowed_roles={GroupProjectRole.write, GroupProjectRole.contribute},
on_behalf_of=None,
ar_guid=ar_guid,
)
Expand Down Expand Up @@ -150,7 +178,10 @@ def validate_iap_jwt_and_get_email(iap_jwt, audience):


get_author = Depends(authenticate)
get_project_readonly_connection = Depends(dependable_get_readonly_project_connection)
get_project_read_connection = Depends(dependable_get_read_project_connection)
get_project_contribute_connection = Depends(
dependable_get_contribute_project_connection
)
get_project_write_connection = Depends(dependable_get_write_project_connection)
get_projectless_db_connection = Depends(dependable_get_connection)
get_projectless_bq_connection = Depends(dependable_get_bq_connection)
Expand Down
Loading

0 comments on commit 8135d4c

Please sign in to comment.