Skip to content

Commit

Permalink
Merge pull request #639 from populationgenomics/dev
Browse files Browse the repository at this point in the history
Release
  • Loading branch information
illusional authored Jan 3, 2024
2 parents 5beb48d + f6c226d commit 2a49cd9
Show file tree
Hide file tree
Showing 52 changed files with 1,352 additions and 378 deletions.
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: check-yaml
exclude: '\.*conda/.*'
Expand All @@ -14,13 +14,13 @@ repos:
- id: check-added-large-files

- repo: https://github.com/igorshubovych/markdownlint-cli
rev: v0.33.0
rev: v0.38.0
hooks:
- id: markdownlint
args: ["--config", ".markdownlint.json"]

- repo: https://github.com/ambv/black
rev: 23.3.0
rev: 23.12.1
hooks:
- id: black
args: [.]
Expand All @@ -29,7 +29,7 @@ repos:
exclude: ^metamist/

- repo: https://github.com/PyCQA/flake8
rev: "6.0.0"
rev: "6.1.0"
hooks:
- id: flake8
additional_dependencies: [flake8-bugbear, flake8-quotes]
Expand All @@ -45,7 +45,7 @@ repos:

# mypy
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.5.1
rev: v1.8.0
hooks:
- id: mypy
args: [
Expand Down
23 changes: 20 additions & 3 deletions api/graphql/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from strawberry.dataloader import DataLoader

from api.utils import get_projectless_db_connection, group_by
from db.python.connect import NotFoundError
from db.python.layers import (
AnalysisLayer,
AssayLayer,
AuditLogLayer,
FamilyLayer,
ParticipantLayer,
SampleLayer,
Expand All @@ -24,16 +24,18 @@
from db.python.tables.project import ProjectPermissionsTable
from db.python.tables.sample import SampleFilter
from db.python.tables.sequencing_group import SequencingGroupFilter
from db.python.utils import GenericFilter, ProjectId
from db.python.utils import GenericFilter, NotFoundError
from models.models import (
AnalysisInternal,
AssayInternal,
FamilyInternal,
ParticipantInternal,
Project,
ProjectId,
SampleInternal,
SequencingGroupInternal,
)
from models.models.audit_log import AuditLogInternal


class LoaderKeys(enum.Enum):
Expand All @@ -44,6 +46,8 @@ class LoaderKeys(enum.Enum):

PROJECTS_FOR_IDS = 'projects_for_id'

AUDIT_LOGS_BY_IDS = 'audit_logs_by_ids'

ANALYSES_FOR_SEQUENCING_GROUPS = 'analyses_for_sequencing_groups'

ASSAYS_FOR_SAMPLES = 'sequences_for_samples'
Expand Down Expand Up @@ -168,6 +172,19 @@ async def wrapped(query: list[dict[str, Any]]) -> list[Any]:
return connected_data_loader_caller


@connected_data_loader(LoaderKeys.AUDIT_LOGS_BY_IDS)
async def load_audit_logs_by_ids(
audit_log_ids: list[int], connection
) -> list[AuditLogInternal | None]:
"""
DataLoader: get_audit_logs_by_ids
"""
alayer = AuditLogLayer(connection)
logs = await alayer.get_for_ids(audit_log_ids)
logs_by_id = {log.id: log for log in logs}
return [logs_by_id.get(a) for a in audit_log_ids]


@connected_data_loader_with_params(LoaderKeys.ASSAYS_FOR_SAMPLES, default_factory=list)
async def load_assays_by_samples(
connection, ids, filter: AssayFilter
Expand Down Expand Up @@ -332,7 +349,7 @@ async def load_projects_for_ids(project_ids: list[int], connection) -> list[Proj
"""
Get projects by IDs
"""
pttable = ProjectPermissionsTable(connection.connection)
pttable = ProjectPermissionsTable(connection)
projects = await pttable.get_and_check_access_to_projects_for_ids(
user=connection.user, project_ids=project_ids, readonly=True
)
Expand Down
46 changes: 40 additions & 6 deletions api/graphql/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from models.models import (
AnalysisInternal,
AssayInternal,
AuditLogInternal,
FamilyInternal,
ParticipantInternal,
Project,
Expand Down Expand Up @@ -204,6 +205,27 @@ async def analyses(
return [GraphQLAnalysis.from_internal(a) for a in internal_analysis]


@strawberry.type
class GraphQLAuditLog:
"""AuditLog GraphQL model"""

id: int
author: str
timestamp: datetime.datetime
ar_guid: str | None
comment: str | None

@staticmethod
def from_internal(audit_log: AuditLogInternal) -> 'GraphQLAuditLog':
return GraphQLAuditLog(
id=audit_log.id,
author=audit_log.author,
timestamp=audit_log.timestamp,
ar_guid=audit_log.ar_guid,
comment=audit_log.comment,
)


@strawberry.type
class GraphQLAnalysis:
"""Analysis GraphQL model"""
Expand Down Expand Up @@ -297,6 +319,7 @@ class GraphQLParticipant:
karyotype: str | None

project_id: strawberry.Private[int]
audit_log_id: strawberry.Private[int | None]

@staticmethod
def from_internal(internal: ParticipantInternal) -> 'GraphQLParticipant':
Expand All @@ -308,6 +331,7 @@ def from_internal(internal: ParticipantInternal) -> 'GraphQLParticipant':
reported_gender=internal.reported_gender,
karyotype=internal.karyotype,
project_id=internal.project,
audit_log_id=internal.audit_log_id,
)

@strawberry.field
Expand Down Expand Up @@ -349,6 +373,16 @@ async def project(self, info: Info, root: 'GraphQLParticipant') -> GraphQLProjec
project = await loader.load(root.project_id)
return GraphQLProject.from_internal(project)

@strawberry.field
async def audit_log(
self, info: Info, root: 'GraphQLParticipant'
) -> GraphQLAuditLog | None:
if root.audit_log_id is None:
return None
loader = info.context[LoaderKeys.AUDIT_LOGS_BY_IDS]
audit_log = await loader.load(root.audit_log_id)
return GraphQLAuditLog.from_internal(audit_log)


@strawberry.type
class GraphQLSample:
Expand All @@ -367,7 +401,7 @@ class GraphQLSample:
project_id: strawberry.Private[int]

@staticmethod
def from_internal(sample: SampleInternal):
def from_internal(sample: SampleInternal) -> 'GraphQLSample':
return GraphQLSample(
id=sample_id_format(sample.id),
external_id=sample.external_id,
Expand Down Expand Up @@ -491,7 +525,7 @@ async def analyses(
loader = info.context[LoaderKeys.ANALYSES_FOR_SEQUENCING_GROUPS]
project_id_map = {}
if project:
ptable = ProjectPermissionsTable(connection.connection)
ptable = ProjectPermissionsTable(connection)
project_ids = project.all_values()
projects = await ptable.get_and_check_access_to_projects_for_names(
user=connection.author, project_names=project_ids, readonly=True
Expand Down Expand Up @@ -564,7 +598,7 @@ def enum(self, info: Info) -> GraphQLEnum:
@strawberry.field()
async def project(self, info: Info, name: str) -> GraphQLProject:
connection = info.context['connection']
ptable = ProjectPermissionsTable(connection.connection)
ptable = ProjectPermissionsTable(connection)
project = await ptable.get_and_check_access_to_project_for_name(
user=connection.author, project_name=name, readonly=True
)
Expand All @@ -583,7 +617,7 @@ async def sample(
active: GraphQLFilter[bool] | None = None,
) -> list[GraphQLSample]:
connection = info.context['connection']
ptable = ProjectPermissionsTable(connection.connection)
ptable = ProjectPermissionsTable(connection)
slayer = SampleLayer(connection)

if not id and not project:
Expand Down Expand Up @@ -631,7 +665,7 @@ async def sequencing_groups(
) -> list[GraphQLSequencingGroup]:
connection = info.context['connection']
sglayer = SequencingGroupLayer(connection)
ptable = ProjectPermissionsTable(connection.connection)
ptable = ProjectPermissionsTable(connection)
if not (project or sample_id or id):
raise ValueError('Must filter by project, sample or id')

Expand Down Expand Up @@ -685,7 +719,7 @@ async def family(self, info: Info, family_id: int) -> GraphQLFamily:
@strawberry.field
async def my_projects(self, info: Info) -> list[GraphQLProject]:
connection = info.context['connection']
ptable = ProjectPermissionsTable(connection.connection)
ptable = ProjectPermissionsTable(connection)
projects = await ptable.get_projects_accessible_by_user(
connection.author, readonly=True
)
Expand Down
12 changes: 7 additions & 5 deletions api/routes/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,12 @@ async def create_analysis(

atable = AnalysisLayer(connection)

if analysis.author:
# special tracking here, if we can't catch it through the header
connection.on_behalf_of = analysis.author

analysis_id = await atable.create_analysis(
analysis.to_internal(),
# analysis-runner: usage is tracked through `on_behalf_of`
author=analysis.author,
)

return analysis_id
Expand Down Expand Up @@ -226,7 +228,7 @@ async def query_analyses(
if not query.projects:
raise ValueError('Must specify "projects"')

pt = ProjectPermissionsTable(connection=connection.connection)
pt = ProjectPermissionsTable(connection)
projects = await pt.get_and_check_access_to_projects_for_names(
user=connection.author, project_names=query.projects, readonly=True
)
Expand All @@ -249,7 +251,7 @@ async def get_analysis_runner_log(
atable = AnalysisLayer(connection)
project_ids = None
if project_names:
pt = ProjectPermissionsTable(connection=connection.connection)
pt = ProjectPermissionsTable(connection)
project_ids = await pt.get_project_ids_from_names_and_user(
connection.author, project_names, readonly=True
)
Expand Down Expand Up @@ -335,7 +337,7 @@ async def get_proportionate_map(
}
}
"""
pt = ProjectPermissionsTable(connection=connection.connection)
pt = ProjectPermissionsTable(connection)
project_ids = await pt.get_project_ids_from_names_and_user(
connection.author, projects, readonly=True
)
Expand Down
2 changes: 1 addition & 1 deletion api/routes/assay.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ async def get_assays_by_criteria(
):
"""Get assays by criteria"""
assay_layer = AssayLayer(connection)
pt = ProjectPermissionsTable(connection.connection)
pt = ProjectPermissionsTable(connection)

pids: list[int] | None = None
if projects:
Expand Down
14 changes: 7 additions & 7 deletions api/routes/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
@router.get('/all', operation_id='getAllProjects', response_model=List[Project])
async def get_all_projects(connection=get_projectless_db_connection):
"""Get list of projects"""
ptable = ProjectPermissionsTable(connection.connection)
ptable = ProjectPermissionsTable(connection)
return await ptable.get_all_projects(author=connection.author)


@router.get('/', operation_id='getMyProjects', response_model=List[str])
async def get_my_projects(connection=get_projectless_db_connection):
"""Get projects I have access to"""
ptable = ProjectPermissionsTable(connection.connection)
ptable = ProjectPermissionsTable(connection)
projects = await ptable.get_projects_accessible_by_user(
author=connection.author, readonly=True
)
Expand All @@ -36,7 +36,7 @@ async def create_project(
"""
Create a new project
"""
ptable = ProjectPermissionsTable(connection.connection)
ptable = ProjectPermissionsTable(connection)
pid = await ptable.create_project(
project_name=name,
dataset_name=dataset,
Expand All @@ -56,7 +56,7 @@ async def create_project(
@router.get('/seqr/all', operation_id='getSeqrProjects')
async def get_seqr_projects(connection: Connection = get_projectless_db_connection):
"""Get SM projects that should sync to seqr"""
ptable = ProjectPermissionsTable(connection.connection)
ptable = ProjectPermissionsTable(connection)
return await ptable.get_seqr_projects()


Expand All @@ -67,7 +67,7 @@ async def update_project(
connection: Connection = get_projectless_db_connection,
):
"""Update a project by project name"""
ptable = ProjectPermissionsTable(connection.connection)
ptable = ProjectPermissionsTable(connection)
return await ptable.update_project(
project_name=project, update=project_update_model, author=connection.author
)
Expand All @@ -84,7 +84,7 @@ async def delete_project_data(
Can optionally delete the project itself.
Requires READ access + project-creator permissions
"""
ptable = ProjectPermissionsTable(connection.connection)
ptable = ProjectPermissionsTable(connection)
p_obj = await ptable.get_and_check_access_to_project_for_name(
user=connection.author, project_name=project, readonly=False
)
Expand All @@ -106,7 +106,7 @@ async def update_project_members(
Update project members for specific read / write group.
Not that this is protected by access to a specific access group
"""
ptable = ProjectPermissionsTable(connection.connection)
ptable = ProjectPermissionsTable(connection)
await ptable.set_group_members(
group_name=ptable.get_project_group_name(project, readonly=readonly),
members=members,
Expand Down
2 changes: 1 addition & 1 deletion api/routes/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ async def get_samples(
"""
st = SampleLayer(connection)

pt = ProjectPermissionsTable(connection.connection)
pt = ProjectPermissionsTable(connection)
pids: list[int] | None = None
if project_ids:
pids = await pt.get_project_ids_from_names_and_user(
Expand Down
2 changes: 1 addition & 1 deletion api/routes/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ async def search_by_keyword(keyword: str, connection=get_projectless_db_connecti
that you are a part of (automatically).
"""
# raise ValueError("Test")
pt = ProjectPermissionsTable(connection.connection)
pt = ProjectPermissionsTable(connection)
projects = await pt.get_projects_accessible_by_user(
connection.author, readonly=True
)
Expand Down
Loading

0 comments on commit 2a49cd9

Please sign in to comment.