Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the CI lint job pass again #3691

Merged
merged 3 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions strawberry/channels/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,17 @@ async def gql_init(self) -> None:
await self.send_json_to(
ConnectionInitMessage(payload=self.connection_params).as_dict()
)
response = await self.receive_json_from()
assert response == ConnectionAckMessage().as_dict()
graphql_transport_ws_response = await self.receive_json_from()
assert graphql_transport_ws_response == ConnectionAckMessage().as_dict()
else:
assert res == (True, GRAPHQL_WS_PROTOCOL)
await self.send_json_to(
GraphQLWSConnectionInitMessage({"type": "connection_init"})
)
response: GraphQLWSConnectionAckMessage = await self.receive_json_from()
assert response["type"] == "connection_ack"
graphql_ws_response: GraphQLWSConnectionAckMessage = (
await self.receive_json_from()
)
assert graphql_ws_response["type"] == "connection_ack"

# Actual `ExecutionResult`` objects are not available client-side, since they
# get transformed into `FormattedExecutionResult` on the wire, but we attempt
Expand Down
14 changes: 9 additions & 5 deletions strawberry/cli/commands/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import importlib
import inspect
from pathlib import Path # noqa: TCH003
from typing import List, Optional, Type
from typing import List, Optional, Type, Union, cast

import rich
import typer
Expand Down Expand Up @@ -61,7 +61,9 @@ def _import_plugin(plugin: str) -> Optional[Type[QueryCodegenPlugin]]:


@functools.lru_cache
def _load_plugin(plugin_path: str) -> Type[QueryCodegenPlugin]:
def _load_plugin(
plugin_path: str,
) -> Union[Type[QueryCodegenPlugin], Type[ConsolePlugin]]:
# try to import plugin_name from current folder
# then try to import from strawberry.codegen.plugins

Expand All @@ -77,7 +79,9 @@ def _load_plugin(plugin_path: str) -> Type[QueryCodegenPlugin]:
return plugin


def _load_plugins(plugin_ids: List[str], query: Path) -> List[QueryCodegenPlugin]:
def _load_plugins(
plugin_ids: List[str], query: Path
) -> List[Union[QueryCodegenPlugin, ConsolePlugin]]:
plugins = []
for ptype_id in plugin_ids:
ptype = _load_plugin(ptype_id)
Expand Down Expand Up @@ -127,11 +131,11 @@ def codegen(

console_plugin_type = _load_plugin(cli_plugin) if cli_plugin else ConsolePlugin
console_plugin = console_plugin_type(output_dir)
assert isinstance(console_plugin, ConsolePlugin)
DoctorJohn marked this conversation as resolved.
Show resolved Hide resolved
console_plugin.before_any_start()

for q in query:
plugins = _load_plugins(selected_plugins, q)
console_plugin.query = q # update the query in the console plugin.
plugins = cast(List[QueryCodegenPlugin], _load_plugins(selected_plugins, q))

code_generator = QueryCodegen(
schema_symbol, plugins=plugins, console_plugin=console_plugin
Expand Down
4 changes: 2 additions & 2 deletions strawberry/experimental/pydantic/object_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _build_dataclass_creation_fields(

return DataclassCreationFields(
name=field.name,
field_type=field_type,
field_type=field_type, # type: ignore
field=strawberry_field,
)

Expand Down Expand Up @@ -198,7 +198,7 @@ def wrap(cls: Any) -> Type[StrawberryTypeFromPydantic[PydanticModel]]:
all_model_fields = [
DataclassCreationFields(
name=field.name,
field_type=field.type,
field_type=field.type, # type: ignore
field=field,
)
for field in extra_fields + private_fields
Expand Down
15 changes: 8 additions & 7 deletions strawberry/subscriptions/protocols/graphql_ws/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,16 +169,17 @@ async def handle_async_results(
await self.websocket.send_json(error_message)
else:
self.subscriptions[operation_id] = agen_or_err

async for result in agen_or_err:
await self.send_data(result, operation_id)
complete_message: CompleteMessage = {
"type": "complete",
"id": operation_id,
}
await self.websocket.send_json(complete_message)

await self.websocket.send_json(
CompleteMessage({"type": "complete", "id": operation_id})
)
except asyncio.CancelledError:
complete_message: CompleteMessage = {"type": "complete", "id": operation_id}
await self.websocket.send_json(complete_message)
await self.websocket.send_json(
DoctorJohn marked this conversation as resolved.
Show resolved Hide resolved
CompleteMessage({"type": "complete", "id": operation_id})
)

async def cleanup_operation(self, operation_id: str) -> None:
if operation_id in self.subscriptions:
Expand Down
Loading