From b2b82b05cfb7071dcd576c703520f67810fd4167 Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Sun, 6 Aug 2023 06:20:40 -0300 Subject: [PATCH] fix: properly handle async iterables that after sliced are not async iterable anymore (#3014) --- RELEASE.md | 7 +++ strawberry/relay/types.py | 26 +++++++---- tests/relay/test_types.py | 95 ++++++++++++++++++++++++++++++++++++++- 3 files changed, 119 insertions(+), 9 deletions(-) create mode 100644 RELEASE.md diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..39d84c0b0b --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,7 @@ +Release type: patch + +This release fixes an issue on `relay.ListConnection` where async iterables that returns +non async iterable objects after being sliced where producing errors. + +This should fix an issue with async strawberry-graphql-django when returning already +prefetched QuerySets. diff --git a/strawberry/relay/types.py b/strawberry/relay/types.py index 57b271b188..0890701dc0 100644 --- a/strawberry/relay/types.py +++ b/strawberry/relay/types.py @@ -871,14 +871,24 @@ async def resolver(): overfetch, ) - assert isinstance(iterator, (AsyncIterator, AsyncIterable)) - edges: List[Edge] = [ - edge_class.resolve_edge( - cls.resolve_node(v, info=info, **kwargs), - cursor=start + i, - ) - async for i, v in aenumerate(iterator) - ] + # The slice above might return an object that now is not async + # iterable anymore (e.g. an already cached django queryset) + if isinstance(iterator, (AsyncIterator, AsyncIterable)): + edges: List[Edge] = [ + edge_class.resolve_edge( + cls.resolve_node(v, info=info, **kwargs), + cursor=start + i, + ) + async for i, v in aenumerate(iterator) + ] + else: + edges: List[Edge] = [ # type: ignore[no-redef] + edge_class.resolve_edge( + cls.resolve_node(v, info=info, **kwargs), + cursor=start + i, + ) + for i, v in enumerate(iterator) + ] has_previous_page = start > 0 if expected is not None and len(edges) == expected + 1: diff --git a/tests/relay/test_types.py b/tests/relay/test_types.py index 542d7cf9fc..6fcb4529f0 100644 --- a/tests/relay/test_types.py +++ b/tests/relay/test_types.py @@ -1,9 +1,11 @@ -from typing import Any, Optional, Union, cast +from typing import Any, AsyncGenerator, AsyncIterable, Optional, Union, cast from typing_extensions import assert_type import pytest +import strawberry from strawberry import relay +from strawberry.relay.utils import to_base64 from strawberry.types.info import Info from .schema import Fruit, FruitAsync, schema @@ -148,3 +150,94 @@ class Foo: gid = relay.GlobalID(type_name="FruitAsync", node_id="1") with pytest.raises(TypeError): fruit = await gid.resolve_node(fake_info, ensure_type=Foo) + + +async def test_resolve_async_list_connection(): + @strawberry.type + class SomeType(relay.Node): + id: relay.NodeID[int] + + @strawberry.type + class Query: + @relay.connection(relay.ListConnection[SomeType]) + async def some_type_conn(self) -> AsyncGenerator[SomeType, None]: + yield SomeType(id=0) + yield SomeType(id=1) + yield SomeType(id=2) + + schema = strawberry.Schema(query=Query) + ret = await schema.execute( + """\ + query { + someTypeConn { + edges { + node { + id + } + } + } + } + """ + ) + assert ret.errors is None + assert ret.data == { + "someTypeConn": { + "edges": [ + {"node": {"id": to_base64("SomeType", 0)}}, + {"node": {"id": to_base64("SomeType", 1)}}, + {"node": {"id": to_base64("SomeType", 2)}}, + ], + } + } + + +async def test_resolve_async_list_connection_but_sync_after_sliced(): + # We are mimicking an object which is async iterable, but when sliced + # returns something that is not anymore. This is similar to an already + # prefetched django QuerySet, which is async iterable by default, but + # when sliced, since it is already prefetched, will return a list. + class Slicer: + def __init__(self, nodes) -> None: + self.nodes = nodes + + async def __aiter__(self): + for n in self.nodes: + yield n + + def __getitem__(self, key): + return self.nodes[key] + + @strawberry.type + class SomeType(relay.Node): + id: relay.NodeID[int] + + @strawberry.type + class Query: + @relay.connection(relay.ListConnection[SomeType]) + async def some_type_conn(self) -> AsyncIterable[SomeType]: + return Slicer([SomeType(id=0), SomeType(id=1), SomeType(id=2)]) + + schema = strawberry.Schema(query=Query) + ret = await schema.execute( + """\ + query { + someTypeConn { + edges { + node { + id + } + } + } + } + """ + ) + assert ret.errors is None + assert ret.data == { + "someTypeConn": { + "edges": [ + {"node": {"id": to_base64("SomeType", 0)}}, + {"node": {"id": to_base64("SomeType", 1)}}, + {"node": {"id": to_base64("SomeType", 2)}}, + ], + } + }