-
-
Notifications
You must be signed in to change notification settings - Fork 535
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Do not block graphql_transport_ws
operations while creating context or validating a single request operation
#2829
Closed
kristjanvalur
wants to merge
20
commits into
strawberry-graphql:main
from
mainframeindustries:kristjan/validate-in-task
Closed
Changes from all commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
3445d24
Add tests for blocking operation validation
kristjanvalur 2be1bf4
Move validation into the task
kristjanvalur 02772ee
Add some error/validation test cases
kristjanvalur 7687c4e
Use duck typing to detect an ExecutionResult/GraphQLExeucitonResult
kristjanvalur 6198007
add an async context getter for tests which is easily patchable.
kristjanvalur 9ef3bf0
Add tests to ensure context_getter does not block connection
kristjanvalur 7cf537e
Move context getter and root getter into worker task
kristjanvalur e41ffe9
Catch top level errors
kristjanvalur c3d6447
Add a test for the task error handler
kristjanvalur a9589e4
add release.md
kristjanvalur ca229a7
Remove dead code, fix coverage
kristjanvalur 1a62b1a
remove special case for AsyncMock
kristjanvalur 76716d0
Add "no cover" to schema code which is designed to not be hit.
kristjanvalur 1ad929b
Update tests for litestar
kristjanvalur 8ced4e4
Litestar integration must be excluded from long test, like Starlite.
kristjanvalur 35a8e68
coverage
kristjanvalur 4084639
Mark some test schema methods as no cover since they are not always used
kristjanvalur d17a4d4
Mypy support for SubscriptionExecutionResult
kristjanvalur a1d0695
ruff
kristjanvalur e43aca8
Remove unused method for coverage
kristjanvalur File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
Release type: patch | ||
|
||
Operations over `graphql-transport-ws` now create the Context and perform validation on | ||
the worker `Task`, thus not blocking the websocket from accepting messages. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,12 +10,13 @@ | |
AsyncGenerator, | ||
AsyncIterator, | ||
Callable, | ||
Coroutine, | ||
Dict, | ||
List, | ||
Optional, | ||
Union, | ||
) | ||
|
||
from graphql import ExecutionResult as GraphQLExecutionResult | ||
from graphql import GraphQLError, GraphQLSyntaxError, parse | ||
|
||
from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( | ||
|
@@ -29,6 +30,7 @@ | |
SubscribeMessage, | ||
SubscribeMessagePayload, | ||
) | ||
from strawberry.types import ExecutionResult | ||
from strawberry.types.graphql import OperationType | ||
from strawberry.types.unset import UNSET | ||
from strawberry.utils.debug import pretty_print_graphql_operation | ||
|
@@ -41,7 +43,6 @@ | |
from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( | ||
GraphQLTransportMessage, | ||
) | ||
from strawberry.types import ExecutionResult | ||
|
||
|
||
class BaseGraphQLTransportWSHandler(ABC): | ||
|
@@ -107,7 +108,7 @@ def on_request_accepted(self) -> None: | |
|
||
async def handle_connection_init_timeout(self) -> None: | ||
task = asyncio.current_task() | ||
assert task | ||
assert task is not None # for typecheckers | ||
try: | ||
delay = self.connection_init_wait_timeout.total_seconds() | ||
await asyncio.sleep(delay=delay) | ||
|
@@ -239,92 +240,100 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None: | |
message.payload.variables, | ||
) | ||
|
||
context = await self.get_context() | ||
if isinstance(context, dict): | ||
context["connection_params"] = self.connection_params | ||
root_value = await self.get_root_value() | ||
|
||
# Get an AsyncGenerator yielding the results | ||
if operation_type == OperationType.SUBSCRIPTION: | ||
result_source = await self.schema.subscribe( | ||
query=message.payload.query, | ||
variable_values=message.payload.variables, | ||
operation_name=message.payload.operationName, | ||
context_value=context, | ||
root_value=root_value, | ||
) | ||
else: | ||
# create AsyncGenerator returning a single result | ||
async def get_result_source() -> AsyncIterator[ExecutionResult]: | ||
yield await self.schema.execute( # type: ignore | ||
# The method to start the operation. Will be called on worker | ||
# thread and so may contain long running async calls without | ||
# blocking the main websocket handler. | ||
async def start_operation() -> Union[AsyncGenerator[Any, None], Any]: | ||
# there is some type mismatch here which we need to gloss over with | ||
# the use of Any. | ||
# subscribe() returns | ||
# Union[AsyncIterator[graphql.ExecutionResult], graphql.ExecutionResult]: | ||
# whereas execute() returns strawberry.types.ExecutionResult. | ||
# These execution result types are similar, but not the same. | ||
|
||
context = await self.get_context() | ||
if isinstance(context, dict): | ||
context["connection_params"] = self.connection_params | ||
root_value = await self.get_root_value() | ||
|
||
if operation_type == OperationType.SUBSCRIPTION: | ||
return await self.schema.subscribe( | ||
query=message.payload.query, | ||
variable_values=message.payload.variables, | ||
operation_name=message.payload.operationName, | ||
context_value=context, | ||
root_value=root_value, | ||
) | ||
else: | ||
# single results behave similarly to subscriptions, | ||
# return either a ExecutionResult or an AsyncGenerator | ||
result = await self.schema.execute( | ||
query=message.payload.query, | ||
variable_values=message.payload.variables, | ||
context_value=context, | ||
root_value=root_value, | ||
operation_name=message.payload.operationName, | ||
) | ||
# Note: result may be SubscriptionExecutionResult or ExecutionResult | ||
# now, but we don't support the former properly yet, hence the "ignore" below. | ||
|
||
result_source = get_result_source() | ||
# Both validation and execution errors are handled the same way. | ||
if isinstance(result, ExecutionResult) and result.errors: | ||
return result | ||
|
||
operation = Operation(self, message.id, operation_type) | ||
# create AsyncGenerator returning a single result | ||
async def single_result() -> AsyncIterator[ExecutionResult]: | ||
yield result # type: ignore | ||
|
||
# Handle initial validation errors | ||
if isinstance(result_source, GraphQLExecutionResult): | ||
assert operation_type == OperationType.SUBSCRIPTION | ||
assert result_source.errors | ||
payload = [err.formatted for err in result_source.errors] | ||
await self.send_message(ErrorMessage(id=message.id, payload=payload)) | ||
self.schema.process_errors(result_source.errors) | ||
return | ||
return single_result() | ||
|
||
# Create task to handle this subscription, reserve the operation ID | ||
operation.task = asyncio.create_task( | ||
self.operation_task(result_source, operation) | ||
) | ||
operation = Operation(self, message.id, operation_type, start_operation) | ||
operation.task = asyncio.create_task(self.operation_task(operation)) | ||
self.operations[message.id] = operation | ||
|
||
async def operation_task( | ||
self, result_source: AsyncGenerator, operation: Operation | ||
) -> None: | ||
"""The operation task's top level method. Cleans-up and de-registers the operation once it is done.""" | ||
# TODO: Handle errors in this method using self.handle_task_exception() | ||
async def operation_task(self, operation: Operation) -> None: | ||
"""The operation task's top level method. | ||
|
||
Cleans-up and de-registers the operation once it is done. | ||
""" | ||
task = asyncio.current_task() | ||
assert task is not None # for type checkers | ||
try: | ||
await self.handle_async_results(result_source, operation) | ||
except BaseException: # pragma: no cover | ||
# cleanup in case of something really unexpected | ||
# wait for generator to be closed to ensure that any existing | ||
# 'finally' statement is called | ||
with suppress(RuntimeError): | ||
await result_source.aclose() | ||
if operation.id in self.operations: | ||
del self.operations[operation.id] | ||
await self.handle_operation(operation) | ||
except asyncio.CancelledError: | ||
raise | ||
else: | ||
await operation.send_message(CompleteMessage(id=operation.id)) | ||
except Exception as error: | ||
# Log any unhandled exceptions in the operation task | ||
await self.handle_task_exception(error) | ||
finally: | ||
# add this task to a list to be reaped later | ||
task = asyncio.current_task() | ||
assert task is not None | ||
# Clenaup. Remove the operation from the list of active operations | ||
if operation.id in self.operations: | ||
del self.operations[operation.id] | ||
# TODO: Stop collecting background tasks, not necessary. | ||
# Add this task to a list to be reaped later | ||
self.completed_tasks.append(task) | ||
|
||
async def handle_async_results( | ||
async def handle_operation( | ||
self, | ||
result_source: AsyncGenerator, | ||
operation: Operation, | ||
) -> None: | ||
try: | ||
async for result in result_source: | ||
if ( | ||
result.errors | ||
and operation.operation_type != OperationType.SUBSCRIPTION | ||
): | ||
error_payload = [err.formatted for err in result.errors] | ||
error_message = ErrorMessage(id=operation.id, payload=error_payload) | ||
await operation.send_message(error_message) | ||
# don't need to call schema.process_errors() here because | ||
# it was already done by schema.execute() | ||
return | ||
else: | ||
result_source = await operation.start_operation() | ||
# result_source is an ExcutionResult-like object or an AsyncGenerator | ||
# Handle validation errors. Cannot check type directly. | ||
if hasattr(result_source, "errors"): | ||
assert result_source.errors | ||
payload = [err.formatted for err in result_source.errors] | ||
await operation.send_message( | ||
ErrorMessage(id=operation.id, payload=payload) | ||
) | ||
if operation.operation_type == OperationType.SUBSCRIPTION: | ||
self.schema.process_errors(result_source.errors) | ||
return | ||
|
||
try: | ||
async for result in result_source: | ||
next_payload = {"data": result.data} | ||
if result.errors: | ||
self.schema.process_errors(result.errors) | ||
|
@@ -333,6 +342,11 @@ async def handle_async_results( | |
] | ||
next_message = NextMessage(id=operation.id, payload=next_payload) | ||
await operation.send_message(next_message) | ||
await operation.send_message(CompleteMessage(id=operation.id)) | ||
finally: | ||
# Close the AsyncGenerator in case of errors or cancellation | ||
await result_source.aclose() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be wrapped in a |
||
|
||
except Exception as error: | ||
# GraphQLErrors are handled by graphql-core and included in the | ||
# ExecutionResult | ||
|
@@ -378,23 +392,35 @@ async def reap_completed_tasks(self) -> None: | |
class Operation: | ||
"""A class encapsulating a single operation with its id. Helps enforce protocol state transition.""" | ||
|
||
__slots__ = ["handler", "id", "operation_type", "completed", "task"] | ||
__slots__ = [ | ||
"handler", | ||
"id", | ||
"operation_type", | ||
"start_operation", | ||
"completed", | ||
"task", | ||
] | ||
|
||
def __init__( | ||
self, | ||
handler: BaseGraphQLTransportWSHandler, | ||
id: str, | ||
operation_type: OperationType, | ||
start_operation: Callable[ | ||
[], Coroutine[Any, Any, Union[Any, AsyncGenerator[Any, None]]] | ||
], | ||
) -> None: | ||
self.handler = handler | ||
self.id = id | ||
self.operation_type = operation_type | ||
self.start_operation = start_operation | ||
self.completed = False | ||
self.task: Optional[asyncio.Task] = None | ||
|
||
async def send_message(self, message: GraphQLTransportMessage) -> None: | ||
# defensive check, should never happen | ||
if self.completed: | ||
return | ||
return # pragma: no cover | ||
if isinstance(message, (CompleteMessage, ErrorMessage)): | ||
self.completed = True | ||
# de-register the operation _before_ sending the final message | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: since we enforce typing a lot here,
assert <something>
is a well known practice for type checkers, so I don't think the comment is required