Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not block graphql_transport_ws operations while creating context or validating a single request operation #2829

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
3445d24
Add tests for blocking operation validation
kristjanvalur Mar 8, 2023
2be1bf4
Move validation into the task
kristjanvalur Apr 8, 2023
02772ee
Add some error/validation test cases
kristjanvalur Apr 9, 2023
7687c4e
Use duck typing to detect an ExecutionResult/GraphQLExeucitonResult
kristjanvalur Apr 9, 2023
6198007
add an async context getter for tests which is easily patchable.
kristjanvalur Apr 9, 2023
9ef3bf0
Add tests to ensure context_getter does not block connection
kristjanvalur Apr 9, 2023
7cf537e
Move context getter and root getter into worker task
kristjanvalur Apr 9, 2023
e41ffe9
Catch top level errors
kristjanvalur Jun 8, 2023
c3d6447
Add a test for the task error handler
kristjanvalur Jun 9, 2023
a9589e4
add release.md
kristjanvalur May 9, 2023
ca229a7
Remove dead code, fix coverage
kristjanvalur Jun 30, 2023
1a62b1a
remove special case for AsyncMock
kristjanvalur Aug 3, 2023
76716d0
Add "no cover" to schema code which is designed to not be hit.
kristjanvalur Nov 8, 2023
1ad929b
Update tests for litestar
kristjanvalur Mar 31, 2024
8ced4e4
Litestar integration must be excluded from long test, like Starlite.
kristjanvalur Mar 31, 2024
35a8e68
coverage
kristjanvalur Mar 31, 2024
4084639
Mark some test schema methods as no cover since they are not always used
kristjanvalur Apr 2, 2024
d17a4d4
Mypy support for SubscriptionExecutionResult
kristjanvalur Sep 7, 2024
a1d0695
ruff
kristjanvalur Sep 7, 2024
e43aca8
Remove unused method for coverage
kristjanvalur Sep 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Release type: patch

Operations over `graphql-transport-ws` now create the Context and perform validation on
the worker `Task`, thus not blocking the websocket from accepting messages.
164 changes: 95 additions & 69 deletions strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@
AsyncGenerator,
AsyncIterator,
Callable,
Coroutine,
Dict,
List,
Optional,
Union,
)

from graphql import ExecutionResult as GraphQLExecutionResult
from graphql import GraphQLError, GraphQLSyntaxError, parse

from strawberry.subscriptions.protocols.graphql_transport_ws.types import (
Expand All @@ -29,6 +30,7 @@
SubscribeMessage,
SubscribeMessagePayload,
)
from strawberry.types import ExecutionResult
from strawberry.types.graphql import OperationType
from strawberry.types.unset import UNSET
from strawberry.utils.debug import pretty_print_graphql_operation
Expand All @@ -41,7 +43,6 @@
from strawberry.subscriptions.protocols.graphql_transport_ws.types import (
GraphQLTransportMessage,
)
from strawberry.types import ExecutionResult


class BaseGraphQLTransportWSHandler(ABC):
Expand Down Expand Up @@ -107,7 +108,7 @@ def on_request_accepted(self) -> None:

async def handle_connection_init_timeout(self) -> None:
task = asyncio.current_task()
assert task
assert task is not None # for typecheckers
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: since we enforce typing a lot here, assert <something> is a well known practice for type checkers, so I don't think the comment is required

Suggested change
assert task is not None # for typecheckers
assert task is not None

try:
delay = self.connection_init_wait_timeout.total_seconds()
await asyncio.sleep(delay=delay)
Expand Down Expand Up @@ -239,92 +240,100 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None:
message.payload.variables,
)

context = await self.get_context()
if isinstance(context, dict):
context["connection_params"] = self.connection_params
root_value = await self.get_root_value()

# Get an AsyncGenerator yielding the results
if operation_type == OperationType.SUBSCRIPTION:
result_source = await self.schema.subscribe(
query=message.payload.query,
variable_values=message.payload.variables,
operation_name=message.payload.operationName,
context_value=context,
root_value=root_value,
)
else:
# create AsyncGenerator returning a single result
async def get_result_source() -> AsyncIterator[ExecutionResult]:
yield await self.schema.execute( # type: ignore
# The method to start the operation. Will be called on worker
# thread and so may contain long running async calls without
# blocking the main websocket handler.
async def start_operation() -> Union[AsyncGenerator[Any, None], Any]:
# there is some type mismatch here which we need to gloss over with
# the use of Any.
# subscribe() returns
# Union[AsyncIterator[graphql.ExecutionResult], graphql.ExecutionResult]:
# whereas execute() returns strawberry.types.ExecutionResult.
# These execution result types are similar, but not the same.

context = await self.get_context()
if isinstance(context, dict):
context["connection_params"] = self.connection_params
root_value = await self.get_root_value()

if operation_type == OperationType.SUBSCRIPTION:
return await self.schema.subscribe(
query=message.payload.query,
variable_values=message.payload.variables,
operation_name=message.payload.operationName,
context_value=context,
root_value=root_value,
)
else:
# single results behave similarly to subscriptions,
# return either a ExecutionResult or an AsyncGenerator
result = await self.schema.execute(
query=message.payload.query,
variable_values=message.payload.variables,
context_value=context,
root_value=root_value,
operation_name=message.payload.operationName,
)
# Note: result may be SubscriptionExecutionResult or ExecutionResult
# now, but we don't support the former properly yet, hence the "ignore" below.

result_source = get_result_source()
# Both validation and execution errors are handled the same way.
if isinstance(result, ExecutionResult) and result.errors:
return result

operation = Operation(self, message.id, operation_type)
# create AsyncGenerator returning a single result
async def single_result() -> AsyncIterator[ExecutionResult]:
yield result # type: ignore

# Handle initial validation errors
if isinstance(result_source, GraphQLExecutionResult):
assert operation_type == OperationType.SUBSCRIPTION
assert result_source.errors
payload = [err.formatted for err in result_source.errors]
await self.send_message(ErrorMessage(id=message.id, payload=payload))
self.schema.process_errors(result_source.errors)
return
return single_result()

# Create task to handle this subscription, reserve the operation ID
operation.task = asyncio.create_task(
self.operation_task(result_source, operation)
)
operation = Operation(self, message.id, operation_type, start_operation)
operation.task = asyncio.create_task(self.operation_task(operation))
self.operations[message.id] = operation

async def operation_task(
self, result_source: AsyncGenerator, operation: Operation
) -> None:
"""The operation task's top level method. Cleans-up and de-registers the operation once it is done."""
# TODO: Handle errors in this method using self.handle_task_exception()
async def operation_task(self, operation: Operation) -> None:
"""The operation task's top level method.

Cleans-up and de-registers the operation once it is done.
"""
task = asyncio.current_task()
assert task is not None # for type checkers
try:
await self.handle_async_results(result_source, operation)
except BaseException: # pragma: no cover
# cleanup in case of something really unexpected
# wait for generator to be closed to ensure that any existing
# 'finally' statement is called
with suppress(RuntimeError):
await result_source.aclose()
if operation.id in self.operations:
del self.operations[operation.id]
await self.handle_operation(operation)
except asyncio.CancelledError:
raise
else:
await operation.send_message(CompleteMessage(id=operation.id))
except Exception as error:
# Log any unhandled exceptions in the operation task
await self.handle_task_exception(error)
finally:
# add this task to a list to be reaped later
task = asyncio.current_task()
assert task is not None
# Clenaup. Remove the operation from the list of active operations
if operation.id in self.operations:
del self.operations[operation.id]
# TODO: Stop collecting background tasks, not necessary.
# Add this task to a list to be reaped later
self.completed_tasks.append(task)

async def handle_async_results(
async def handle_operation(
self,
result_source: AsyncGenerator,
operation: Operation,
) -> None:
try:
async for result in result_source:
if (
result.errors
and operation.operation_type != OperationType.SUBSCRIPTION
):
error_payload = [err.formatted for err in result.errors]
error_message = ErrorMessage(id=operation.id, payload=error_payload)
await operation.send_message(error_message)
# don't need to call schema.process_errors() here because
# it was already done by schema.execute()
return
else:
result_source = await operation.start_operation()
# result_source is an ExcutionResult-like object or an AsyncGenerator
# Handle validation errors. Cannot check type directly.
if hasattr(result_source, "errors"):
assert result_source.errors
payload = [err.formatted for err in result_source.errors]
await operation.send_message(
ErrorMessage(id=operation.id, payload=payload)
)
if operation.operation_type == OperationType.SUBSCRIPTION:
self.schema.process_errors(result_source.errors)
return

try:
async for result in result_source:
next_payload = {"data": result.data}
if result.errors:
self.schema.process_errors(result.errors)
Expand All @@ -333,6 +342,11 @@ async def handle_async_results(
]
next_message = NextMessage(id=operation.id, payload=next_payload)
await operation.send_message(next_message)
await operation.send_message(CompleteMessage(id=operation.id))
finally:
# Close the AsyncGenerator in case of errors or cancellation
await result_source.aclose()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be wrapped in a with suppress(asyncio.CancelledError)?


except Exception as error:
# GraphQLErrors are handled by graphql-core and included in the
# ExecutionResult
Expand Down Expand Up @@ -378,23 +392,35 @@ async def reap_completed_tasks(self) -> None:
class Operation:
"""A class encapsulating a single operation with its id. Helps enforce protocol state transition."""

__slots__ = ["handler", "id", "operation_type", "completed", "task"]
__slots__ = [
"handler",
"id",
"operation_type",
"start_operation",
"completed",
"task",
]

def __init__(
self,
handler: BaseGraphQLTransportWSHandler,
id: str,
operation_type: OperationType,
start_operation: Callable[
[], Coroutine[Any, Any, Union[Any, AsyncGenerator[Any, None]]]
],
) -> None:
self.handler = handler
self.id = id
self.operation_type = operation_type
self.start_operation = start_operation
self.completed = False
self.task: Optional[asyncio.Task] = None

async def send_message(self, message: GraphQLTransportMessage) -> None:
# defensive check, should never happen
if self.completed:
return
return # pragma: no cover
if isinstance(message, (CompleteMessage, ErrorMessage)):
self.completed = True
# de-register the operation _before_ sending the final message
Expand Down
4 changes: 2 additions & 2 deletions tests/http/clients/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from strawberry.types import ExecutionResult
from tests.views.schema import Query, schema

from ..context import get_context
from ..context import get_context_async as get_context
from .base import (
JSON,
DebuggableGraphQLTransportWSMixin,
Expand Down Expand Up @@ -50,7 +50,7 @@ async def get_context(
) -> object:
context = await super().get_context(request, response)

return get_context(context)
return await get_context(context)

async def get_root_value(self, request: web.Request) -> Query:
await super().get_root_value(request) # for coverage
Expand Down
4 changes: 2 additions & 2 deletions tests/http/clients/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from strawberry.types import ExecutionResult
from tests.views.schema import Query, schema

from ..context import get_context
from ..context import get_context_async as get_context
from .base import (
JSON,
DebuggableGraphQLTransportWSMixin,
Expand Down Expand Up @@ -56,7 +56,7 @@ async def get_context(
) -> object:
context = await super().get_context(request, response)

return get_context(context)
return await get_context(context)

async def process_result(
self, request: Request, result: ExecutionResult
Expand Down
6 changes: 3 additions & 3 deletions tests/http/clients/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from strawberry.http.typevars import Context, RootValue
from tests.views.schema import Query, schema

from ..context import get_context
from ..context import get_context, get_context_async
from .base import (
JSON,
HttpClient,
Expand Down Expand Up @@ -77,7 +77,7 @@ async def get_context(self, *args: str, **kwargs: Any) -> object:
context["connectionInitTimeoutTask"] = getattr(
self._handler, "connection_init_timeout_task", None
)
for key, val in get_context({}).items():
for key, val in (await get_context_async({})).items():
context[key] = val
return context

Expand All @@ -95,7 +95,7 @@ async def get_root_value(self, request: ChannelsConsumer) -> Optional[RootValue]
async def get_context(self, request: ChannelsConsumer, response: Any) -> Context:
context = await super().get_context(request, response)

return get_context(context)
return await get_context_async(context)

async def process_result(
self, request: ChannelsConsumer, result: Any
Expand Down
4 changes: 2 additions & 2 deletions tests/http/clients/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from strawberry.types import ExecutionResult
from tests.views.schema import Query, schema

from ..context import get_context
from ..context import get_context_async as get_context
from .asgi import AsgiWebSocketClient
from .base import (
JSON,
Expand Down Expand Up @@ -50,7 +50,7 @@ async def fastapi_get_context(
ws: WebSocket = None, # type: ignore
custom_value: str = Depends(custom_context_dependency),
) -> Dict[str, object]:
return get_context(
return await get_context(
{
"request": request or ws,
"background_tasks": background_tasks,
Expand Down
8 changes: 2 additions & 6 deletions tests/http/clients/litestar.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from strawberry.types import ExecutionResult
from tests.views.schema import Query, schema

from ..context import get_context
from ..context import get_context_async as get_context
from .base import (
JSON,
DebuggableGraphQLTransportWSMixin,
Expand All @@ -30,12 +30,8 @@
)


def custom_context_dependency() -> str:
return "Hi!"


async def litestar_get_context(request: Request = None):
return get_context({"request": request})
return await get_context({"request": request})


async def get_root_value(request: Request = None):
Expand Down
15 changes: 15 additions & 0 deletions tests/http/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,21 @@


def get_context(context: object) -> Dict[str, object]:
return get_context_inner(context)


# a patchable method for unittests
def get_context_inner(context: object) -> Dict[str, object]:
assert isinstance(context, dict)
return {**context, "custom_value": "a value from context"}


# async version for async frameworks
async def get_context_async(context: object) -> Dict[str, object]:
return await get_context_async_inner(context)


# a patchable method for unittests
async def get_context_async_inner(context: object) -> Dict[str, object]:
assert isinstance(context, dict)
return {**context, "custom_value": "a value from context"}
Loading
Loading