Skip to content

Commit

Permalink
Integrate websockets into the async base view (#3638)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Patrick Arminio <[email protected]>
  • Loading branch information
3 people authored Oct 5, 2024
1 parent 1b33547 commit 40a0504
Show file tree
Hide file tree
Showing 52 changed files with 891 additions and 1,638 deletions.
4 changes: 4 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Release type: minor

Starting with this release, WebSocket logic now lives in the base class shared between all HTTP integrations.
This makes the behaviour of WebSockets much more consistent between integrations and easier to maintain.
7 changes: 7 additions & 0 deletions TWEET.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
🚀 Starting with Strawberry $version, WebSocket logic now lives in the base
class shared across all HTTP integrations. More consistent behavior and easier
maintenance for WebSockets across integrations. 🎉

Thanks to $contributor for the PR 👏

$release_url
6 changes: 0 additions & 6 deletions strawberry/aiohttp/handlers/__init__.py

This file was deleted.

62 changes: 0 additions & 62 deletions strawberry/aiohttp/handlers/graphql_transport_ws_handler.py

This file was deleted.

69 changes: 0 additions & 69 deletions strawberry/aiohttp/handlers/graphql_ws_handler.py

This file was deleted.

102 changes: 58 additions & 44 deletions strawberry/aiohttp/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import warnings
from datetime import timedelta
from io import BytesIO
from json.decoder import JSONDecodeError
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -16,15 +17,16 @@
Union,
cast,
)
from typing_extensions import TypeGuard

from aiohttp import web
from aiohttp import http, web
from aiohttp.multipart import BodyPartReader
from strawberry.aiohttp.handlers import (
GraphQLTransportWSHandler,
GraphQLWSHandler,
from strawberry.http.async_base_view import (
AsyncBaseHTTPView,
AsyncHTTPRequestAdapter,
AsyncWebSocketAdapter,
)
from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncHTTPRequestAdapter
from strawberry.http.exceptions import HTTPException
from strawberry.http.exceptions import HTTPException, NonJsonMessageReceived
from strawberry.http.types import FormData, HTTPMethod, QueryParams
from strawberry.http.typevars import (
Context,
Expand Down Expand Up @@ -79,11 +81,36 @@ def content_type(self) -> Optional[str]:
return self.headers.get("content-type")


class AioHTTPWebSocketAdapter(AsyncWebSocketAdapter):
def __init__(self, request: web.Request, ws: web.WebSocketResponse) -> None:
self.request = request
self.ws = ws

async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]:
async for ws_message in self.ws:
if ws_message.type == http.WSMsgType.TEXT:
try:
yield ws_message.json()
except JSONDecodeError:
raise NonJsonMessageReceived()

elif ws_message.type == http.WSMsgType.BINARY:
raise NonJsonMessageReceived()

async def send_json(self, message: Mapping[str, object]) -> None:
await self.ws.send_json(message)

async def close(self, code: int, reason: str) -> None:
await self.ws.close(code=code, message=reason.encode())


class GraphQLView(
AsyncBaseHTTPView[
web.Request,
Union[web.Response, web.StreamResponse],
web.Response,
web.Request,
web.WebSocketResponse,
Context,
RootValue,
]
Expand All @@ -92,10 +119,9 @@ class GraphQLView(
# bare handler function.
_is_coroutine = asyncio.coroutines._is_coroutine # type: ignore[attr-defined]

graphql_transport_ws_handler_class = GraphQLTransportWSHandler
graphql_ws_handler_class = GraphQLWSHandler
allow_queries_via_get = True
request_adapter_class = AioHTTPRequestAdapter
websocket_adapter_class = AioHTTPWebSocketAdapter

def __init__(
self,
Expand Down Expand Up @@ -138,48 +164,36 @@ async def render_graphql_ide(self, request: web.Request) -> web.Response:
async def get_sub_response(self, request: web.Request) -> web.Response:
return web.Response()

async def __call__(self, request: web.Request) -> web.StreamResponse:
def is_websocket_request(self, request: web.Request) -> TypeGuard[web.Request]:
ws = web.WebSocketResponse(protocols=self.subscription_protocols)
ws_test = ws.can_prepare(request)

if not ws_test.ok:
try:
return await self.run(request=request)
except HTTPException as e:
return web.Response(
body=e.reason,
status=e.status_code,
)

if ws_test.protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL:
return await self.graphql_transport_ws_handler_class(
schema=self.schema,
debug=self.debug,
connection_init_wait_timeout=self.connection_init_wait_timeout,
get_context=self.get_context, # type: ignore
get_root_value=self.get_root_value,
request=request,
).handle()
elif ws_test.protocol == GRAPHQL_WS_PROTOCOL:
return await self.graphql_ws_handler_class(
schema=self.schema,
debug=self.debug,
keep_alive=self.keep_alive,
keep_alive_interval=self.keep_alive_interval,
get_context=self.get_context,
get_root_value=self.get_root_value,
request=request,
).handle()
else:
await ws.prepare(request)
await ws.close(code=4406, message=b"Subprotocol not acceptable")
return ws
return ws.can_prepare(request).ok

async def pick_websocket_subprotocol(self, request: web.Request) -> Optional[str]:
ws = web.WebSocketResponse(protocols=self.subscription_protocols)
return ws.can_prepare(request).protocol

async def create_websocket_response(
self, request: web.Request, subprotocol: Optional[str]
) -> web.WebSocketResponse:
protocols = [subprotocol] if subprotocol else []
ws = web.WebSocketResponse(protocols=protocols)
await ws.prepare(request)
return ws

async def __call__(self, request: web.Request) -> web.StreamResponse:
try:
return await self.run(request=request)
except HTTPException as e:
return web.Response(
body=e.reason,
status=e.status_code,
)

async def get_root_value(self, request: web.Request) -> Optional[RootValue]:
return None

async def get_context(
self, request: web.Request, response: web.Response
self, request: web.Request, response: Union[web.Response, web.WebSocketResponse]
) -> Context:
return {"request": request, "response": response} # type: ignore

Expand Down
Loading

0 comments on commit 40a0504

Please sign in to comment.