Skip to content

Commit

Permalink
Merging dev changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
milo-hyben committed Jan 12, 2024
2 parents f6465e2 + b4cd5de commit 12e40d1
Show file tree
Hide file tree
Showing 70 changed files with 1,671 additions and 458 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 6.5.0
current_version = 6.6.2
commit = True
tag = False
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>[A-z0-9-]+)
Expand Down
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: check-yaml
exclude: '\.*conda/.*'
Expand All @@ -14,13 +14,13 @@ repos:
- id: check-added-large-files

- repo: https://github.com/igorshubovych/markdownlint-cli
rev: v0.33.0
rev: v0.38.0
hooks:
- id: markdownlint
args: ["--config", ".markdownlint.json"]

- repo: https://github.com/ambv/black
rev: 23.3.0
rev: 23.12.1
hooks:
- id: black
args: [.]
Expand All @@ -29,7 +29,7 @@ repos:
exclude: ^metamist/

- repo: https://github.com/PyCQA/flake8
rev: "6.0.0"
rev: "6.1.0"
hooks:
- id: flake8
additional_dependencies: [flake8-bugbear, flake8-quotes]
Expand All @@ -45,7 +45,7 @@ repos:

# mypy
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.5.1
rev: v1.8.0
hooks:
- id: mypy
args: [
Expand Down
36 changes: 33 additions & 3 deletions api/graphql/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from strawberry.dataloader import DataLoader

from api.utils import get_projectless_db_connection, group_by
from db.python.connect import NotFoundError
from db.python.layers import (
AnalysisLayer,
AssayLayer,
AuditLogLayer,
FamilyLayer,
ParticipantLayer,
SampleLayer,
Expand All @@ -24,16 +24,18 @@
from db.python.tables.project import ProjectPermissionsTable
from db.python.tables.sample import SampleFilter
from db.python.tables.sequencing_group import SequencingGroupFilter
from db.python.utils import GenericFilter, ProjectId
from db.python.utils import GenericFilter, NotFoundError
from models.models import (
AnalysisInternal,
AssayInternal,
FamilyInternal,
ParticipantInternal,
Project,
ProjectId,
SampleInternal,
SequencingGroupInternal,
)
from models.models.audit_log import AuditLogInternal


class LoaderKeys(enum.Enum):
Expand All @@ -44,6 +46,9 @@ class LoaderKeys(enum.Enum):

PROJECTS_FOR_IDS = 'projects_for_id'

AUDIT_LOGS_BY_IDS = 'audit_logs_by_ids'
AUDIT_LOGS_BY_ANALYSIS_IDS = 'audit_logs_by_analysis_ids'

ANALYSES_FOR_SEQUENCING_GROUPS = 'analyses_for_sequencing_groups'

ASSAYS_FOR_SAMPLES = 'sequences_for_samples'
Expand Down Expand Up @@ -168,6 +173,31 @@ async def wrapped(query: list[dict[str, Any]]) -> list[Any]:
return connected_data_loader_caller


@connected_data_loader(LoaderKeys.AUDIT_LOGS_BY_IDS)
async def load_audit_logs_by_ids(
audit_log_ids: list[int], connection
) -> list[AuditLogInternal | None]:
"""
DataLoader: get_audit_logs_by_ids
"""
alayer = AuditLogLayer(connection)
logs = await alayer.get_for_ids(audit_log_ids)
logs_by_id = {log.id: log for log in logs}
return [logs_by_id.get(a) for a in audit_log_ids]


@connected_data_loader(LoaderKeys.AUDIT_LOGS_BY_ANALYSIS_IDS)
async def load_audit_logs_by_analysis_ids(
analysis_ids: list[int], connection
) -> list[list[AuditLogInternal]]:
"""
DataLoader: get_audit_logs_by_analysis_ids
"""
alayer = AnalysisLayer(connection)
logs = await alayer.get_audit_logs_by_analysis_ids(analysis_ids)
return [logs.get(a) or [] for a in analysis_ids]


@connected_data_loader_with_params(LoaderKeys.ASSAYS_FOR_SAMPLES, default_factory=list)
async def load_assays_by_samples(
connection, ids, filter: AssayFilter
Expand Down Expand Up @@ -332,7 +362,7 @@ async def load_projects_for_ids(project_ids: list[int], connection) -> list[Proj
"""
Get projects by IDs
"""
pttable = ProjectPermissionsTable(connection.connection)
pttable = ProjectPermissionsTable(connection)
projects = await pttable.get_and_check_access_to_projects_for_ids(
user=connection.user, project_ids=project_ids, readonly=True
)
Expand Down
66 changes: 53 additions & 13 deletions api/graphql/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from models.models import (
AnalysisInternal,
AssayInternal,
AuditLogInternal,
FamilyInternal,
ParticipantInternal,
Project,
Expand Down Expand Up @@ -204,6 +205,29 @@ async def analyses(
return [GraphQLAnalysis.from_internal(a) for a in internal_analysis]


@strawberry.type
class GraphQLAuditLog:
"""AuditLog GraphQL model"""

id: int
author: str
timestamp: datetime.datetime
ar_guid: str | None
comment: str | None
meta: strawberry.scalars.JSON

@staticmethod
def from_internal(audit_log: AuditLogInternal) -> 'GraphQLAuditLog':
return GraphQLAuditLog(
id=audit_log.id,
author=audit_log.author,
timestamp=audit_log.timestamp,
ar_guid=audit_log.ar_guid,
comment=audit_log.comment,
meta=audit_log.meta,
)


@strawberry.type
class GraphQLAnalysis:
"""Analysis GraphQL model"""
Expand All @@ -215,7 +239,6 @@ class GraphQLAnalysis:
timestamp_completed: datetime.datetime | None = None
active: bool
meta: strawberry.scalars.JSON
author: str

@staticmethod
def from_internal(internal: AnalysisInternal) -> 'GraphQLAnalysis':
Expand All @@ -227,7 +250,6 @@ def from_internal(internal: AnalysisInternal) -> 'GraphQLAnalysis':
timestamp_completed=internal.timestamp_completed,
active=internal.active,
meta=internal.meta,
author=internal.author,
)

@strawberry.field
Expand All @@ -244,6 +266,14 @@ async def project(self, info: Info, root: 'GraphQLAnalysis') -> GraphQLProject:
project = await loader.load(root.project)
return GraphQLProject.from_internal(project)

@strawberry.field
async def audit_logs(
self, info: Info, root: 'GraphQLAnalysis'
) -> list[GraphQLAuditLog]:
loader = info.context[LoaderKeys.AUDIT_LOGS_BY_ANALYSIS_IDS]
audit_logs = await loader.load(root.id)
return [GraphQLAuditLog.from_internal(audit_log) for audit_log in audit_logs]


@strawberry.type
class GraphQLFamily:
Expand Down Expand Up @@ -297,6 +327,7 @@ class GraphQLParticipant:
karyotype: str | None

project_id: strawberry.Private[int]
audit_log_id: strawberry.Private[int | None]

@staticmethod
def from_internal(internal: ParticipantInternal) -> 'GraphQLParticipant':
Expand All @@ -308,6 +339,7 @@ def from_internal(internal: ParticipantInternal) -> 'GraphQLParticipant':
reported_gender=internal.reported_gender,
karyotype=internal.karyotype,
project_id=internal.project,
audit_log_id=internal.audit_log_id,
)

@strawberry.field
Expand Down Expand Up @@ -349,6 +381,16 @@ async def project(self, info: Info, root: 'GraphQLParticipant') -> GraphQLProjec
project = await loader.load(root.project_id)
return GraphQLProject.from_internal(project)

@strawberry.field
async def audit_log(
self, info: Info, root: 'GraphQLParticipant'
) -> GraphQLAuditLog | None:
if root.audit_log_id is None:
return None
loader = info.context[LoaderKeys.AUDIT_LOGS_BY_IDS]
audit_log = await loader.load(root.audit_log_id)
return GraphQLAuditLog.from_internal(audit_log)


@strawberry.type
class GraphQLSample:
Expand All @@ -359,22 +401,20 @@ class GraphQLSample:
active: bool
meta: strawberry.scalars.JSON
type: str
author: str | None

# keep as integers, because they're useful to reference in the fields below
internal_id: strawberry.Private[int]
participant_id: strawberry.Private[int]
project_id: strawberry.Private[int]

@staticmethod
def from_internal(sample: SampleInternal):
def from_internal(sample: SampleInternal) -> 'GraphQLSample':
return GraphQLSample(
id=sample_id_format(sample.id),
external_id=sample.external_id,
active=sample.active,
meta=sample.meta,
type=sample.type,
author=sample.author,
# internals
internal_id=sample.id,
participant_id=sample.participant_id,
Expand Down Expand Up @@ -450,7 +490,7 @@ class GraphQLSequencingGroup:
technology: str
platform: str
meta: strawberry.scalars.JSON
external_ids: strawberry.scalars.JSON | None
external_ids: strawberry.scalars.JSON

internal_id: strawberry.Private[int]
sample_id: strawberry.Private[int]
Expand All @@ -464,7 +504,7 @@ def from_internal(internal: SequencingGroupInternal) -> 'GraphQLSequencingGroup'
technology=internal.technology,
platform=internal.platform,
meta=internal.meta,
external_ids=internal.external_ids,
external_ids=internal.external_ids or {},
# internal
internal_id=internal.id,
sample_id=internal.sample_id,
Expand All @@ -491,7 +531,7 @@ async def analyses(
loader = info.context[LoaderKeys.ANALYSES_FOR_SEQUENCING_GROUPS]
project_id_map = {}
if project:
ptable = ProjectPermissionsTable(connection.connection)
ptable = ProjectPermissionsTable(connection)
project_ids = project.all_values()
projects = await ptable.get_and_check_access_to_projects_for_names(
user=connection.author, project_names=project_ids, readonly=True
Expand Down Expand Up @@ -541,7 +581,7 @@ def from_internal(internal: AssayInternal) -> 'GraphQLAssay':
id=internal.id,
type=internal.type,
meta=internal.meta,
external_ids=internal.external_ids,
external_ids=internal.external_ids or {},
# internal
sample_id=internal.sample_id,
)
Expand All @@ -564,7 +604,7 @@ def enum(self, info: Info) -> GraphQLEnum:
@strawberry.field()
async def project(self, info: Info, name: str) -> GraphQLProject:
connection = info.context['connection']
ptable = ProjectPermissionsTable(connection.connection)
ptable = ProjectPermissionsTable(connection)
project = await ptable.get_and_check_access_to_project_for_name(
user=connection.author, project_name=name, readonly=True
)
Expand All @@ -583,7 +623,7 @@ async def sample(
active: GraphQLFilter[bool] | None = None,
) -> list[GraphQLSample]:
connection = info.context['connection']
ptable = ProjectPermissionsTable(connection.connection)
ptable = ProjectPermissionsTable(connection)
slayer = SampleLayer(connection)

if not id and not project:
Expand Down Expand Up @@ -631,7 +671,7 @@ async def sequencing_groups(
) -> list[GraphQLSequencingGroup]:
connection = info.context['connection']
sglayer = SequencingGroupLayer(connection)
ptable = ProjectPermissionsTable(connection.connection)
ptable = ProjectPermissionsTable(connection)
if not (project or sample_id or id):
raise ValueError('Must filter by project, sample or id')

Expand Down Expand Up @@ -685,7 +725,7 @@ async def family(self, info: Info, family_id: int) -> GraphQLFamily:
@strawberry.field
async def my_projects(self, info: Info) -> list[GraphQLProject]:
connection = info.context['connection']
ptable = ProjectPermissionsTable(connection.connection)
ptable = ProjectPermissionsTable(connection)
projects = await ptable.get_projects_accessible_by_user(
connection.author, readonly=True
)
Expand Down
12 changes: 7 additions & 5 deletions api/routes/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,12 @@ async def create_analysis(

atable = AnalysisLayer(connection)

if analysis.author:
# special tracking here, if we can't catch it through the header
connection.on_behalf_of = analysis.author

analysis_id = await atable.create_analysis(
analysis.to_internal(),
# analysis-runner: usage is tracked through `on_behalf_of`
author=analysis.author,
)

return analysis_id
Expand Down Expand Up @@ -226,7 +228,7 @@ async def query_analyses(
if not query.projects:
raise ValueError('Must specify "projects"')

pt = ProjectPermissionsTable(connection=connection.connection)
pt = ProjectPermissionsTable(connection)
projects = await pt.get_and_check_access_to_projects_for_names(
user=connection.author, project_names=query.projects, readonly=True
)
Expand All @@ -249,7 +251,7 @@ async def get_analysis_runner_log(
atable = AnalysisLayer(connection)
project_ids = None
if project_names:
pt = ProjectPermissionsTable(connection=connection.connection)
pt = ProjectPermissionsTable(connection)
project_ids = await pt.get_project_ids_from_names_and_user(
connection.author, project_names, readonly=True
)
Expand Down Expand Up @@ -335,7 +337,7 @@ async def get_proportionate_map(
}
}
"""
pt = ProjectPermissionsTable(connection=connection.connection)
pt = ProjectPermissionsTable(connection)
project_ids = await pt.get_project_ids_from_names_and_user(
connection.author, projects, readonly=True
)
Expand Down
2 changes: 1 addition & 1 deletion api/routes/assay.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ async def get_assays_by_criteria(
):
"""Get assays by criteria"""
assay_layer = AssayLayer(connection)
pt = ProjectPermissionsTable(connection.connection)
pt = ProjectPermissionsTable(connection)

pids: list[int] | None = None
if projects:
Expand Down
Loading

0 comments on commit 12e40d1

Please sign in to comment.