From 07e7165b9c06dcd58c663280f82662385f22a494 Mon Sep 17 00:00:00 2001 From: Jonathan Ehwald Date: Thu, 7 Nov 2024 17:06:41 +0100 Subject: [PATCH 1/3] Fix variable redefinitions --- strawberry/channels/testing.py | 10 ++++++---- .../protocols/graphql_ws/handlers.py | 15 ++++++++------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/strawberry/channels/testing.py b/strawberry/channels/testing.py index ec129263b5..511ed34d1b 100644 --- a/strawberry/channels/testing.py +++ b/strawberry/channels/testing.py @@ -111,15 +111,17 @@ async def gql_init(self) -> None: await self.send_json_to( ConnectionInitMessage(payload=self.connection_params).as_dict() ) - response = await self.receive_json_from() - assert response == ConnectionAckMessage().as_dict() + graphql_transport_ws_response = await self.receive_json_from() + assert graphql_transport_ws_response == ConnectionAckMessage().as_dict() else: assert res == (True, GRAPHQL_WS_PROTOCOL) await self.send_json_to( GraphQLWSConnectionInitMessage({"type": "connection_init"}) ) - response: GraphQLWSConnectionAckMessage = await self.receive_json_from() - assert response["type"] == "connection_ack" + graphql_ws_response: GraphQLWSConnectionAckMessage = ( + await self.receive_json_from() + ) + assert graphql_ws_response["type"] == "connection_ack" # Actual `ExecutionResult`` objects are not available client-side, since they # get transformed into `FormattedExecutionResult` on the wire, but we attempt diff --git a/strawberry/subscriptions/protocols/graphql_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_ws/handlers.py index 9b2eddbf85..0eda72cb56 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_ws/handlers.py @@ -169,16 +169,17 @@ async def handle_async_results( await self.websocket.send_json(error_message) else: self.subscriptions[operation_id] = agen_or_err + async for result in agen_or_err: await self.send_data(result, operation_id) - complete_message: CompleteMessage = { - "type": "complete", - "id": operation_id, - } - await self.websocket.send_json(complete_message) + + await self.websocket.send_json( + CompleteMessage({"type": "complete", "id": operation_id}) + ) except asyncio.CancelledError: - complete_message: CompleteMessage = {"type": "complete", "id": operation_id} - await self.websocket.send_json(complete_message) + await self.websocket.send_json( + CompleteMessage({"type": "complete", "id": operation_id}) + ) async def cleanup_operation(self, operation_id: str) -> None: if operation_id in self.subscriptions: From e4b6889df1d74f97c139cb19c6338ba7f655638b Mon Sep 17 00:00:00 2001 From: Jonathan Ehwald Date: Thu, 7 Nov 2024 17:08:48 +0100 Subject: [PATCH 2/3] Ignore trust issues --- strawberry/experimental/pydantic/object_type.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/strawberry/experimental/pydantic/object_type.py b/strawberry/experimental/pydantic/object_type.py index c153b84a0b..c9e1bb1161 100644 --- a/strawberry/experimental/pydantic/object_type.py +++ b/strawberry/experimental/pydantic/object_type.py @@ -105,7 +105,7 @@ def _build_dataclass_creation_fields( return DataclassCreationFields( name=field.name, - field_type=field_type, + field_type=field_type, # type: ignore field=strawberry_field, ) @@ -198,7 +198,7 @@ def wrap(cls: Any) -> Type[StrawberryTypeFromPydantic[PydanticModel]]: all_model_fields = [ DataclassCreationFields( name=field.name, - field_type=field.type, + field_type=field.type, # type: ignore field=field, ) for field in extra_fields + private_fields From 07b286ad5d718af7493c6a7be732c8b154b92f49 Mon Sep 17 00:00:00 2001 From: Jonathan Ehwald Date: Thu, 7 Nov 2024 17:11:05 +0100 Subject: [PATCH 3/3] Narrow types of dynamically loaded plugins --- strawberry/cli/commands/codegen.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/strawberry/cli/commands/codegen.py b/strawberry/cli/commands/codegen.py index 6e398a54f0..6fe784da30 100644 --- a/strawberry/cli/commands/codegen.py +++ b/strawberry/cli/commands/codegen.py @@ -4,7 +4,7 @@ import importlib import inspect from pathlib import Path # noqa: TCH003 -from typing import List, Optional, Type +from typing import List, Optional, Type, Union, cast import rich import typer @@ -61,7 +61,9 @@ def _import_plugin(plugin: str) -> Optional[Type[QueryCodegenPlugin]]: @functools.lru_cache -def _load_plugin(plugin_path: str) -> Type[QueryCodegenPlugin]: +def _load_plugin( + plugin_path: str, +) -> Union[Type[QueryCodegenPlugin], Type[ConsolePlugin]]: # try to import plugin_name from current folder # then try to import from strawberry.codegen.plugins @@ -77,7 +79,9 @@ def _load_plugin(plugin_path: str) -> Type[QueryCodegenPlugin]: return plugin -def _load_plugins(plugin_ids: List[str], query: Path) -> List[QueryCodegenPlugin]: +def _load_plugins( + plugin_ids: List[str], query: Path +) -> List[Union[QueryCodegenPlugin, ConsolePlugin]]: plugins = [] for ptype_id in plugin_ids: ptype = _load_plugin(ptype_id) @@ -127,11 +131,11 @@ def codegen( console_plugin_type = _load_plugin(cli_plugin) if cli_plugin else ConsolePlugin console_plugin = console_plugin_type(output_dir) + assert isinstance(console_plugin, ConsolePlugin) console_plugin.before_any_start() for q in query: - plugins = _load_plugins(selected_plugins, q) - console_plugin.query = q # update the query in the console plugin. + plugins = cast(List[QueryCodegenPlugin], _load_plugins(selected_plugins, q)) code_generator = QueryCodegen( schema_symbol, plugins=plugins, console_plugin=console_plugin