Skip to content

Commit

Permalink
switch graphql schema to use concrete types for filters
Browse files Browse the repository at this point in the history
This works around an issue in strawberry graphql where using generic
input types on fields that may return many items causes severe
performance degredation.
  • Loading branch information
dancoates committed Jun 24, 2024
1 parent d6c089a commit 71ffd6b
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 64 deletions.
89 changes: 88 additions & 1 deletion api/graphql/filters.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import datetime
from typing import Callable, Generic, TypeVar

import strawberry

from db.python.utils import GenericFilter, GenericMetaFilter
from models.enums.analysis import AnalysisStatus

T = TypeVar('T')
Y = TypeVar('Y')
Expand All @@ -26,7 +28,7 @@ def all_values(self):
"""
Get all values used anywhere in a filter, useful for getting values to map later
"""
v = []
v: list[T] = []
if self.eq:
v.append(self.eq)
if self.in_:
Expand Down Expand Up @@ -80,6 +82,91 @@ def to_internal_filter_mapped(self, f: Callable[[T], Y]) -> GenericFilter[Y]:
)


GraphQLAnalysisStatus = strawberry.enum(AnalysisStatus)


# The below concrete types are specified individually because there is a performance
# issue in strawberry graphql where the usage of generics in input types causes major
# slowdowns.
# @see https://github.com/strawberry-graphql/strawberry/issues/3544
@strawberry.input(description='Filter for GraphQL queries')
class GraphQLFilterStr(GraphQLFilter[str]):
eq: str | None = None
in_: list[str] | None = None
nin: list[str] | None = None
gt: str | None = None
gte: str | None = None
lt: str | None = None
lte: str | None = None
contains: str | None = None
icontains: str | None = None


@strawberry.input(description='Filter for GraphQL queries')
class GraphQLFilterInt(GraphQLFilter[int]):
eq: int | None = None
in_: list[int] | None = None
nin: list[int] | None = None
gt: int | None = None
gte: int | None = None
lt: int | None = None
lte: int | None = None
contains: int | None = None
icontains: int | None = None


@strawberry.input(description='Filter for GraphQL queries')
class GraphQLFilterBool(GraphQLFilter[bool]):
eq: bool | None = None
in_: list[bool] | None = None
nin: list[bool] | None = None
gt: bool | None = None
gte: bool | None = None
lt: bool | None = None
lte: bool | None = None
contains: bool | None = None
icontains: bool | None = None


@strawberry.input(description='Filter for GraphQL queries')
class GraphQLFilterAnalysisStatus(GraphQLFilter[AnalysisStatus]):
eq: AnalysisStatus | None = None
in_: list[AnalysisStatus] | None = None
nin: list[AnalysisStatus] | None = None
gt: AnalysisStatus | None = None
gte: AnalysisStatus | None = None
lt: AnalysisStatus | None = None
lte: AnalysisStatus | None = None
contains: AnalysisStatus | None = None
icontains: AnalysisStatus | None = None


@strawberry.input(description='Filter for GraphQL queries')
class GraphQLFilterDatetime(GraphQLFilter[datetime.datetime]):
eq: datetime.datetime | None = None
in_: list[datetime.datetime] | None = None
nin: list[datetime.datetime] | None = None
gt: datetime.datetime | None = None
gte: datetime.datetime | None = None
lt: datetime.datetime | None = None
lte: datetime.datetime | None = None
contains: datetime.datetime | None = None
icontains: datetime.datetime | None = None


@strawberry.input(description='Filter for GraphQL queries')
class GraphQLFilterDate(GraphQLFilter[datetime.date]):
eq: datetime.date | None = None
in_: list[datetime.date] | None = None
nin: list[datetime.date] | None = None
gt: datetime.date | None = None
gte: datetime.date | None = None
lt: datetime.date | None = None
lte: datetime.date | None = None
contains: datetime.date | None = None
icontains: datetime.date | None = None


GraphQLMetaFilter = strawberry.scalars.JSON


Expand Down
131 changes: 68 additions & 63 deletions api/graphql/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@
from strawberry.types import Info

from api.graphql.filters import (
GraphQLFilter,
GraphQLFilterAnalysisStatus,
GraphQLFilterBool,
GraphQLFilterDate,
GraphQLFilterDatetime,
GraphQLFilterInt,
GraphQLFilterStr,
GraphQLMetaFilter,
graphql_meta_filter_to_internal_filter,
)
Expand Down Expand Up @@ -89,8 +94,6 @@ async def m(info: Info) -> list[str]:

GraphQLEnum = strawberry.type(type('GraphQLEnum', (object,), enum_methods))

GraphQLAnalysisStatus = strawberry.enum(AnalysisStatus)


@strawberry.experimental.pydantic.type(model=OurDNALostSample, all_fields=True) # type: ignore
class GraphQLOurDNALostSample:
Expand Down Expand Up @@ -250,11 +253,11 @@ async def analysis_runner(
self,
info: Info,
root: 'Project',
ar_guid: GraphQLFilter[str] | None = None,
author: GraphQLFilter[str] | None = None,
repository: GraphQLFilter[str] | None = None,
access_level: GraphQLFilter[str] | None = None,
environment: GraphQLFilter[str] | None = None,
ar_guid: GraphQLFilterStr | None = None,
author: GraphQLFilterStr | None = None,
repository: GraphQLFilterStr | None = None,
access_level: GraphQLFilterStr | None = None,
environment: GraphQLFilterStr | None = None,
) -> list['GraphQLAnalysisRunner']:
connection = info.context['connection']
alayer = AnalysisRunnerLayer(connection)
Expand Down Expand Up @@ -314,8 +317,8 @@ async def families(
self,
info: Info,
root: 'GraphQLProject',
id: GraphQLFilter[int] | None = None,
external_id: GraphQLFilter[str] | None = None,
id: GraphQLFilterInt | None = None,
external_id: GraphQLFilterStr | None = None,
) -> list['GraphQLFamily']:
# don't need a data loader here as we're presuming we're not often running
# the "families" method for many projects at once. If so, we might need to fix that
Expand All @@ -342,9 +345,9 @@ async def samples(
self,
info: Info,
root: 'GraphQLProject',
type: GraphQLFilter[str] | None = None,
external_id: GraphQLFilter[str] | None = None,
id: GraphQLFilter[str] | None = None,
type: GraphQLFilterStr | None = None,
external_id: GraphQLFilterStr | None = None,
id: GraphQLFilterStr | None = None,
meta: GraphQLMetaFilter | None = None,
) -> list['GraphQLSample']:
loader = info.context[LoaderKeys.SAMPLES_FOR_PROJECTS]
Expand All @@ -362,12 +365,12 @@ async def sequencing_groups(
self,
info: Info,
root: 'GraphQLProject',
id: GraphQLFilter[str] | None = None,
external_id: GraphQLFilter[str] | None = None,
type: GraphQLFilter[str] | None = None,
technology: GraphQLFilter[str] | None = None,
platform: GraphQLFilter[str] | None = None,
active_only: GraphQLFilter[bool] | None = None,
id: GraphQLFilterStr | None = None,
external_id: GraphQLFilterStr | None = None,
type: GraphQLFilterStr | None = None,
technology: GraphQLFilterStr | None = None,
platform: GraphQLFilterStr | None = None,
active_only: GraphQLFilterBool | None = None,
) -> list['GraphQLSequencingGroup']:
loader = info.context[LoaderKeys.SEQUENCING_GROUPS_FOR_PROJECTS]
filter_ = SequencingGroupFilter(
Expand All @@ -394,12 +397,12 @@ async def analyses(
self,
info: Info,
root: 'Project',
type: GraphQLFilter[str] | None = None,
status: GraphQLFilter[GraphQLAnalysisStatus] | None = None,
active: GraphQLFilter[bool] | None = None,
type: GraphQLFilterStr | None = None,
status: GraphQLFilterAnalysisStatus | None = None,
active: GraphQLFilterBool | None = None,
meta: GraphQLMetaFilter | None = None,
timestamp_completed: GraphQLFilter[datetime.datetime] | None = None,
ids: GraphQLFilter[int] | None = None,
timestamp_completed: GraphQLFilterDatetime | None = None,
ids: GraphQLFilterInt | None = None,
) -> list['GraphQLAnalysis']:
connection = info.context['connection']
connection.project = root.id
Expand Down Expand Up @@ -429,11 +432,11 @@ async def cohorts(
self,
info: Info,
root: 'Project',
id: GraphQLFilter[str] | None = None,
name: GraphQLFilter[str] | None = None,
author: GraphQLFilter[str] | None = None,
template_id: GraphQLFilter[str] | None = None,
timestamp: GraphQLFilter[datetime.datetime] | None = None,
id: GraphQLFilterStr | None = None,
name: GraphQLFilterStr | None = None,
author: GraphQLFilterStr | None = None,
template_id: GraphQLFilterStr | None = None,
timestamp: GraphQLFilterDatetime | None = None,
) -> list['GraphQLCohort']:
connection = info.context['connection']
connection.project = root.id
Expand Down Expand Up @@ -651,9 +654,9 @@ async def samples(
self,
info: Info,
root: 'GraphQLParticipant',
type: GraphQLFilter[str] | None = None,
type: GraphQLFilterStr | None = None,
meta: GraphQLMetaFilter | None = None,
active: GraphQLFilter[bool] | None = None,
active: GraphQLFilterBool | None = None,
) -> list['GraphQLSample']:
filter_ = SampleFilter(
type=type.to_internal_filter() if type else None,
Expand Down Expand Up @@ -751,7 +754,7 @@ async def assays(
self,
info: Info,
root: 'GraphQLSample',
type: GraphQLFilter[str] | None = None,
type: GraphQLFilterStr | None = None,
meta: GraphQLMetaFilter | None = None,
) -> list['GraphQLAssay']:
loader_assays_for_sample_ids = info.context[LoaderKeys.ASSAYS_FOR_SAMPLES]
Expand All @@ -774,12 +777,12 @@ async def sequencing_groups(
self,
info: Info,
root: 'GraphQLSample',
id: GraphQLFilter[str] | None = None,
type: GraphQLFilter[str] | None = None,
technology: GraphQLFilter[str] | None = None,
platform: GraphQLFilter[str] | None = None,
id: GraphQLFilterStr | None = None,
type: GraphQLFilterStr | None = None,
technology: GraphQLFilterStr | None = None,
platform: GraphQLFilterStr | None = None,
meta: GraphQLMetaFilter | None = None,
active_only: GraphQLFilter[bool] | None = None,
active_only: GraphQLFilterBool | None = None,
) -> list['GraphQLSequencingGroup']:
loader = info.context[LoaderKeys.SEQUENCING_GROUPS_FOR_SAMPLES]

Expand Down Expand Up @@ -847,11 +850,11 @@ async def analyses(
self,
info: Info,
root: 'GraphQLSequencingGroup',
status: GraphQLFilter[GraphQLAnalysisStatus] | None = None,
type: GraphQLFilter[str] | None = None,
status: GraphQLFilterAnalysisStatus | None = None,
type: GraphQLFilterStr | None = None,
meta: GraphQLMetaFilter | None = None,
active: GraphQLFilter[bool] | None = None,
project: GraphQLFilter[str] | None = None,
active: GraphQLFilterBool | None = None,
project: GraphQLFilterStr | None = None,
) -> list[GraphQLAnalysis]:
connection = info.context['connection']
loader = info.context[LoaderKeys.ANALYSES_FOR_SEQUENCING_GROUPS]
Expand Down Expand Up @@ -997,8 +1000,8 @@ def enum(self, info: Info) -> GraphQLEnum: # type: ignore
async def cohort_templates(
self,
info: Info,
id: GraphQLFilter[str] | None = None,
project: GraphQLFilter[str] | None = None,
id: GraphQLFilterStr | None = None,
project: GraphQLFilterStr | None = None,
) -> list[GraphQLCohortTemplate]:
connection = info.context['connection']
cohort_layer = CohortLayer(connection)
Expand Down Expand Up @@ -1046,11 +1049,11 @@ async def cohort_templates(
async def cohorts(
self,
info: Info,
id: GraphQLFilter[str] | None = None,
project: GraphQLFilter[str] | None = None,
name: GraphQLFilter[str] | None = None,
author: GraphQLFilter[str] | None = None,
template_id: GraphQLFilter[str] | None = None,
id: GraphQLFilterStr | None = None,
project: GraphQLFilterStr | None = None,
name: GraphQLFilterStr | None = None,
author: GraphQLFilterStr | None = None,
template_id: GraphQLFilterStr | None = None,
) -> list[GraphQLCohort]:
connection = info.context['connection']
cohort_layer = CohortLayer(connection)
Expand Down Expand Up @@ -1098,13 +1101,13 @@ async def project(self, info: Info, name: str) -> GraphQLProject:
async def sample(
self,
info: Info,
id: GraphQLFilter[str] | None = None,
project: GraphQLFilter[str] | None = None,
type: GraphQLFilter[str] | None = None,
id: GraphQLFilterStr | None = None,
project: GraphQLFilterStr | None = None,
type: GraphQLFilterStr | None = None,
meta: GraphQLMetaFilter | None = None,
external_id: GraphQLFilter[str] | None = None,
participant_id: GraphQLFilter[int] | None = None,
active: GraphQLFilter[bool] | None = None,
external_id: GraphQLFilterStr | None = None,
participant_id: GraphQLFilterInt | None = None,
active: GraphQLFilterBool | None = None,
) -> list[GraphQLSample]:
connection = info.context['connection']
ptable = ProjectPermissionsTable(connection)
Expand Down Expand Up @@ -1148,14 +1151,14 @@ async def sample(
async def sequencing_groups(
self,
info: Info,
id: GraphQLFilter[str] | None = None,
project: GraphQLFilter[str] | None = None,
sample_id: GraphQLFilter[str] | None = None,
type: GraphQLFilter[str] | None = None,
technology: GraphQLFilter[str] | None = None,
platform: GraphQLFilter[str] | None = None,
active_only: GraphQLFilter[bool] | None = None,
created_on: GraphQLFilter[datetime.date] | None = None,
id: GraphQLFilterStr | None = None,
project: GraphQLFilterStr | None = None,
sample_id: GraphQLFilterStr | None = None,
type: GraphQLFilterStr | None = None,
technology: GraphQLFilterStr | None = None,
platform: GraphQLFilterStr | None = None,
active_only: GraphQLFilterBool | None = None,
created_on: GraphQLFilterDate | None = None,
assay_meta: GraphQLMetaFilter | None = None,
has_cram: bool | None = None,
has_gvcf: bool | None = None,
Expand Down Expand Up @@ -1254,7 +1257,9 @@ async def analysis_runner(


schema = strawberry.Schema(
query=Query, mutation=None, extensions=[QueryDepthLimiter(max_depth=10)]
query=Query,
mutation=None,
extensions=[QueryDepthLimiter(max_depth=10)],
)
MetamistGraphQLRouter: GraphQLRouter = GraphQLRouter(
schema, graphiql=True, context_getter=get_context
Expand Down

0 comments on commit 71ffd6b

Please sign in to comment.