diff --git a/tests/http/clients/aiohttp.py b/tests/http/clients/aiohttp.py index d817910d2c..89b0c718e8 100644 --- a/tests/http/clients/aiohttp.py +++ b/tests/http/clients/aiohttp.py @@ -16,7 +16,7 @@ from strawberry.types import ExecutionResult from tests.views.schema import Query, schema -from ..context import get_context_async as get_context +from ..context import get_context from .base import ( JSON, DebuggableGraphQLTransportWSHandler, @@ -39,7 +39,7 @@ async def get_context( ) -> object: context = await super().get_context(request, response) - return await get_context(context) + return get_context(context) async def get_root_value(self, request: web.Request) -> Query: await super().get_root_value(request) # for coverage diff --git a/tests/http/clients/asgi.py b/tests/http/clients/asgi.py index 335ec06aa9..7910e02f73 100644 --- a/tests/http/clients/asgi.py +++ b/tests/http/clients/asgi.py @@ -17,7 +17,7 @@ from strawberry.types import ExecutionResult from tests.views.schema import Query, schema -from ..context import get_context_async as get_context +from ..context import get_context from .base import ( JSON, DebuggableGraphQLTransportWSHandler, @@ -45,7 +45,7 @@ async def get_context( ) -> object: context = await super().get_context(request, response) - return await get_context(context) + return get_context(context) async def process_result( self, request: Request, result: ExecutionResult diff --git a/tests/http/clients/channels.py b/tests/http/clients/channels.py index 7e0627572c..bde2364128 100644 --- a/tests/http/clients/channels.py +++ b/tests/http/clients/channels.py @@ -20,7 +20,7 @@ from strawberry.http.typevars import Context, RootValue from tests.views.schema import Query, schema -from ..context import get_context, get_context_async +from ..context import get_context from .base import ( JSON, DebuggableGraphQLTransportWSHandler, @@ -78,7 +78,7 @@ async def get_root_value(self, request: ChannelsConsumer) -> Optional[RootValue] async def get_context(self, request: ChannelsConsumer, response: Any) -> Context: context = await super().get_context(request, response) - return await get_context_async(context) + return get_context(context) async def process_result( self, request: ChannelsConsumer, result: Any diff --git a/tests/http/clients/fastapi.py b/tests/http/clients/fastapi.py index 4da966ce40..b1b80625fa 100644 --- a/tests/http/clients/fastapi.py +++ b/tests/http/clients/fastapi.py @@ -16,7 +16,7 @@ from strawberry.types import ExecutionResult from tests.views.schema import Query, schema -from ..context import get_context_async as get_context +from ..context import get_context from .asgi import AsgiWebSocketClient from .base import ( JSON, @@ -39,7 +39,7 @@ async def fastapi_get_context( ws: WebSocket = None, # type: ignore custom_value: str = Depends(custom_context_dependency), ) -> Dict[str, object]: - return await get_context( + return get_context( { "request": request or ws, "background_tasks": background_tasks, diff --git a/tests/http/clients/litestar.py b/tests/http/clients/litestar.py index 661c8fc1ee..b29be52b32 100644 --- a/tests/http/clients/litestar.py +++ b/tests/http/clients/litestar.py @@ -16,7 +16,7 @@ from strawberry.types import ExecutionResult from tests.views.schema import Query, schema -from ..context import get_context_async as get_context +from ..context import get_context from .base import ( JSON, DebuggableGraphQLTransportWSHandler, @@ -30,7 +30,7 @@ async def litestar_get_context(request: Request = None): - return await get_context({"request": request}) + return get_context({"request": request}) async def get_root_value(request: Request = None): diff --git a/tests/http/context.py b/tests/http/context.py index c1ce5dbecf..99985b2434 100644 --- a/tests/http/context.py +++ b/tests/http/context.py @@ -2,21 +2,6 @@ def get_context(context: object) -> Dict[str, object]: - return get_context_inner(context) - - -# a patchable method for unittests -def get_context_inner(context: object) -> Dict[str, object]: assert isinstance(context, dict) - return {**context, "custom_value": "a value from context"} - -# async version for async frameworks -async def get_context_async(context: object) -> Dict[str, object]: - return await get_context_async_inner(context) - - -# a patchable method for unittests -async def get_context_async_inner(context: object) -> Dict[str, object]: - assert isinstance(context, dict) return {**context, "custom_value": "a value from context"}