From ddcc9528add387a6de410bb0a9aad102180722ba Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Sat, 12 Oct 2024 13:55:41 -0300 Subject: [PATCH] feat: Allow connection to be used without relay for generic cursor pagination --- RELEASE.md | 16 + ...tion.md => connection-wrong-annotation.md} | 0 ...> connection-wrong-resolver-annotation.md} | 26 +- docs/guides/pagination/connections.md | 622 ++++++------------ docs/guides/pagination/overview.md | 6 +- docs/guides/relay.md | 220 +------ strawberry/pagination/__init__.py | 21 + strawberry/pagination/exceptions.py | 82 +++ strawberry/pagination/fields.py | 338 ++++++++++ strawberry/pagination/types.py | 376 +++++++++++ strawberry/pagination/utils.py | 198 ++++++ strawberry/relay/__init__.py | 53 +- strawberry/relay/exceptions.py | 74 +-- strawberry/relay/fields.py | 333 +--------- strawberry/relay/types.py | 379 +---------- strawberry/relay/utils.py | 224 +------ strawberry/utils.py | 0 tests/pagination/__init__.py | 0 tests/pagination/test_exceptions.py | 70 ++ tests/pagination/test_types.py | 146 ++++ tests/{relay => pagination}/test_utils.py | 15 +- tests/relay/test_exceptions.py | 65 +- tests/relay/test_types.py | 130 +--- 23 files changed, 1620 insertions(+), 1774 deletions(-) create mode 100644 RELEASE.md rename docs/errors/{relay-wrong-annotation.md => connection-wrong-annotation.md} (100%) rename docs/errors/{relay-wrong-resolver-annotation.md => connection-wrong-resolver-annotation.md} (69%) create mode 100644 strawberry/pagination/__init__.py create mode 100644 strawberry/pagination/exceptions.py create mode 100644 strawberry/pagination/fields.py create mode 100644 strawberry/pagination/types.py create mode 100644 strawberry/pagination/utils.py create mode 100644 strawberry/utils.py create mode 100644 tests/pagination/__init__.py create mode 100644 tests/pagination/test_exceptions.py create mode 100644 tests/pagination/test_types.py rename tests/{relay => pagination}/test_utils.py (94%) diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..bf84d2ce69 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,16 @@ +Release type: minor + +This release moves the `Connection` implementation outside the relay package, +allowing it to be used for general-purpose cursor pagination. + +The following now can be imported from `strawberry.pagination`: + +- `Connection` - base generic class for implementing connections +- `ListConnection` - a limit-offset implementation of the connection +- `connection` - field decorator for creating connections + +Those can still be used together with the relay package, but importing from it +is now deprecated. + +You can read more about connections in the +[Strawberry Connection Docs](https://strawberry.rocks/docs/guides/pagination/connections). diff --git a/docs/errors/relay-wrong-annotation.md b/docs/errors/connection-wrong-annotation.md similarity index 100% rename from docs/errors/relay-wrong-annotation.md rename to docs/errors/connection-wrong-annotation.md diff --git a/docs/errors/relay-wrong-resolver-annotation.md b/docs/errors/connection-wrong-resolver-annotation.md similarity index 69% rename from docs/errors/relay-wrong-resolver-annotation.md rename to docs/errors/connection-wrong-resolver-annotation.md index 80c732dd95..1a99d767be 100644 --- a/docs/errors/relay-wrong-resolver-annotation.md +++ b/docs/errors/connection-wrong-resolver-annotation.md @@ -1,13 +1,13 @@ --- -title: Relay wrong resolver annotation Error +title: Connection wrong resolver annotation Error --- -# Relay wrong resolver annotation error +# Connection wrong resolver annotation error ## Description -This error is thrown when a field on a relay connection was defined with a -resolver that returns something that is not compatible with pagination. +This error is thrown when a field on a connection was defined with a resolver +that returns something that is not compatible with pagination. For example, the following code would throw this error: @@ -15,19 +15,19 @@ For example, the following code would throw this error: from typing import Any import strawberry -from strawberry import relay +from strawberry.pagination import connection @strawberry.type -class MyType(relay.Node): ... +class MyType(Node): ... @strawberry.type class Query: - @relay.connection(relay.Connection[MyType]) + @connection(Connection[MyType]) def some_connection_returning_mytype(self) -> MyType: ... - @relay.connection(relay.Connection[MyType]) + @connection(Connection[MyType]) def some_connection_returning_any(self) -> Any: ... ``` @@ -53,22 +53,22 @@ For example: from typing import Any import strawberry -from strawberry import relay +from strawberry.pagination import connection, Connection @strawberry.type -class MyType(relay.Node): ... +class MyType: ... @strawberry.type class Query: - @relay.connection(relay.Connection[MyType]) + @connection(Connection[MyType]) def some_connection(self) -> Iterable[MyType]: ... ``` Note that if you are returning a type different than the connection type, you will need to subclass the connection type and override its `resolve_node` - method to convert it to the correct type, as explained in the [relay - guide](../guides/relay). + method to convert it to the correct type, as explained in the [pagination + guide](../guides/pagination). diff --git a/docs/guides/pagination/connections.md b/docs/guides/pagination/connections.md index 37c6a89f3a..44aab943e1 100644 --- a/docs/guides/pagination/connections.md +++ b/docs/guides/pagination/connections.md @@ -1,14 +1,14 @@ --- -title: Pagination - Implementing the Relay Connection Specification +title: Pagination - Implementing the Connection Specification --- -# Implementing the Relay Connection Specification +# Implementing the Connection Specification We naively implemented cursor based pagination in the [previous tutorial](./cursor-based.md). To ensure a consistent implementation of this pattern, the Relay project has a formal -[specification](https://relay.dev/graphql/connections.htm) you can follow for -building GraphQL APIs which use a cursor based connection pattern. +[connection specification](https://relay.dev/graphql/connections.htm), and +strawberry provides a `Connection` generic type to implement it. By the end of this tutorial, we should be able to return a connection of users when requested. @@ -74,348 +74,270 @@ query getUsers { -## Connections +## Connection A Connection represents a paginated relationship between two entities. This pattern is used when the relationship itself has attributes. For example, we might have a connection of users to represent a paginated list of users. -Let us define a Connection type which takes in a Generic ObjectType. +Strawberry provides a `relay.Connection` class, which contains the basics of +what you need to implement the specification, but don't implement any pagination +logic. To use it, all you need to do is to subclass it and implement its +abstract `resolve_connection` classmethod. -```py -# example.py + -from typing import Generic, TypeVar +For basic cases, Strawberry provides a `ListConnection` which already implements +a `resolve_connection`, which we are going to look in the next section. -import strawberry - - -GenericType = TypeVar("GenericType") - - -@strawberry.type -class Connection(Generic[GenericType]): - page_info: "PageInfo" = strawberry.field( - description="Information to aid in pagination." - ) - - edges: list["Edge[GenericType]"] = strawberry.field( - description="A list of edges in this connection." - ) -``` - -Connections must have atleast two fields: `edges` and `page_info`. - -The `page_info` field contains metadata about the connection. Following the -Relay specification, we can define a `PageInfo` type like this: - -```py line=22-38 -# example.py + -from typing import Generic, TypeVar +Lets look at an example for a connection of users: +```python import strawberry - - -GenericType = TypeVar("GenericType") - - -@strawberry.type -class Connection(Generic[GenericType]): - page_info: "PageInfo" = strawberry.field( - description="Information to aid in pagination." - ) - - edges: list["Edge[GenericType]"] = strawberry.field( - description="A list of edges in this connection." - ) +from strawberry.pagination import Connection, Edge, to_base64 @strawberry.type -class PageInfo: - has_next_page: bool = strawberry.field( - description="When paginating forwards, are there more items?" - ) - - has_previous_page: bool = strawberry.field( - description="When paginating backwards, are there more items?" - ) - - start_cursor: Optional[str] = strawberry.field( - description="When paginating backwards, the cursor to continue." - ) - - end_cursor: Optional[str] = strawberry.field( - description="When paginating forwards, the cursor to continue." - ) -``` - -You can read more about the `PageInfo` type at: - -- https://graphql.org/learn/pagination/#pagination-and-edges -- https://relay.dev/graphql/connections.htm - -The `edges` field must return a list type that wraps an edge type. - -Following the Relay specification, let us define an Edge that takes in a generic -ObjectType. - -```py line=41-49 -# example.py - -from typing import Generic, TypeVar - -import strawberry +class UserConnection(Connection[User]): + @classmethod + def resolve_connection( + cls, + nodes: Iterable[Fruit], + *, + info: Optional[Info] = None, + before: Optional[str] = None, + after: Optional[str] = None, + first: Optional[int] = None, + last: Optional[int] = None, + ): + # NOTE: This is a showcase implementation and is far from + # being optimal performance wise + edges_mapping = { + to_base64("cursor-name", n.name): Edge( + node=n, + cursor=to_base64("cursor-name", n.name), + ) + for n in sorted(nodes, key=lambda f: f.name) + } + edges = list(edges_mapping.values()) + first_edge = edges[0] if edges else None + last_edge = edges[-1] if edges else None + if after is not None: + after_edge_idx = edges.index(edges_mapping[after]) + edges = [e for e in edges if edges.index(e) > after_edge_idx] -GenericType = TypeVar("GenericType") + if before is not None: + before_edge_idx = edges.index(edges_mapping[before]) + edges = [e for e in edges if edges.index(e) < before_edge_idx] + if first is not None: + edges = edges[:first] -@strawberry.type -class Connection(Generic[GenericType]): - page_info: "PageInfo" = strawberry.field( - description="Information to aid in pagination." - ) + if last is not None: + edges = edges[-last:] - edges: list["Edge[GenericType]"] = strawberry.field( - description="A list of edges in this connection." - ) + return cls( + edges=edges, + page_info=strawberry.relay.PageInfo( + start_cursor=edges[0].cursor if edges else None, + end_cursor=edges[-1].cursor if edges else None, + has_previous_page=( + first_edge is not None and bool(edges) and edges[0] != first_edge + ), + has_next_page=( + last_edge is not None and bool(edges) and edges[-1] != last_edge + ), + ), + ) @strawberry.type -class PageInfo: - has_next_page: bool = strawberry.field( - description="When paginating forwards, are there more items?" - ) +class Query: + @connection(UserConnection) + def get_users(self) -> Iterable[User]: + # This can be a database query, a generator, an async generator, etc + return some_function_that_returns_users() +``` - has_previous_page: bool = strawberry.field( - description="When paginating backwards, are there more items?" - ) +This would generate a schema like this: - start_cursor: Optional[str] = strawberry.field( - description="When paginating backwards, the cursor to continue." - ) +```graphql +type PageInfo { + hasNextPage: Boolean! + hasPreviousPage: Boolean! + startCursor: String + endCursor: String +} - end_cursor: Optional[str] = strawberry.field( - description="When paginating forwards, the cursor to continue." - ) +type User { + id: ID! + name: String! + occupation: String! + age: Int! +} +type UserEdge { + cursor: String! + node: User! +} -@strawberry.type -class Edge(Generic[GenericType]): - node: GenericType = strawberry.field(description="The item at the end of the edge.") +type UserConnection { + pageInfo: PageInfo! + edges: [UserEdge!]! +} - cursor: str = strawberry.field(description="A cursor for use in pagination.") +type Query { + getUsers( + first: Int = null + last: Int = null + before: String = null + after: String = null + ): UserConnection! +} ``` -EdgeTypes must have atleast two fields - `cursor` and `node`. Each edge has it's -own cursor and item (represented by the `node` field). +## ListConnection -Now that we have the types needed to implement pagination using Relay -Connections, let us use them to paginate a list of users. For simplicity's sake, -let our dataset be a list of dictionaries. +Strawberry also provides `ListConnection`, a subclass of `Connection` that +implementes a limit/offset pagination algorithm by using slices. -```py line=7-32 -# example.py +If a limit/offset pagination is enough for your needs, the above example can be +simplified to use `ListConnection` to a basic resolver that returns one of: -from typing import Generic, TypeVar +- `List[]` +- `Iterator[]` +- `Iterable[]` +- `AsyncIterator[]` +- `AsyncIterable[]` +- `Generator[, Any, Any]` +- `AsyncGenerator[, Any]` -import strawberry +For example: -user_data = [ - { - "id": 1, - "name": "Norman Osborn", - "occupation": "Founder, Oscorp Industries", - "age": 42, - }, - { - "id": 2, - "name": "Peter Parker", - "occupation": "Freelance Photographer, The Daily Bugle", - "age": 20, - }, - { - "id": 3, - "name": "Harold Osborn", - "occupation": "President, Oscorp Industries", - "age": 19, - }, - { - "id": 4, - "name": "Eddie Brock", - "occupation": "Journalist, The Eddie Brock Report", - "age": 20, - }, -] - - -GenericType = TypeVar("GenericType") +```python +import strawberry +from strawberry.pagination import Connection, Edge, to_base64 @strawberry.type -class Connection(Generic[GenericType]): - page_info: "PageInfo" = strawberry.field( - description="Information to aid in pagination." - ) +class Query: + @connection(ListConnection[User]) + def get_users(self) -> Iterable[User]: + # This can be a database query, a generator, an async generator, etc + return some_function_that_returns_users() +``` - edges: list["Edge[GenericType]"] = strawberry.field( - description="A list of edges in this connection." - ) +Because the implementation will use a slice to paginate the data, that means you +can override what the slice does by customizing the `__getitem__` method of the +object returned by your nodes resolver. +For example, when working with `Django`, `resolve_nodes` can return a +`QuerySet`, meaning that the slice on it will translate to a `LIMIT`/`OFFSET` in +the SQL query, making it fetch only the data that is needed from the database. -@strawberry.type -class PageInfo: - has_next_page: bool = strawberry.field( - description="When paginating forwards, are there more items?" - ) +Also note that if that object doesn't have a `__getitem__` attribute, it will +use `itertools.islice` to paginate it, meaning that when a generator is being +resolved it will only generate as much results as needed for the given +pagination, the worst case scenario being the last results needing to be +returned. - has_previous_page: bool = strawberry.field( - description="When paginating backwards, are there more items?" - ) +## Custom Connection Arguments - start_cursor: Optional[str] = strawberry.field( - description="When paginating backwards, the cursor to continue." - ) +By default the connection will automatically insert some arguments for it to be +able to paginate the results. Those are: - end_cursor: Optional[str] = strawberry.field( - description="When paginating forwards, the cursor to continue." - ) +- `before`: Returns the items in the list that come before the specified cursor +- `after`: Returns the items in the list that come after the " "specified cursor +- `first`: Returns the first n items from the list +- `last`: Returns the items in the list that come after the " "specified cursor +You can still define extra arguments to be used by your own resolver or custom +pagination logic, and those will be merged together. For example, suppose we +want to return the pagination of all users whose name starts with a given +string. We could do that like this: +```python @strawberry.type -class Edge(Generic[GenericType]): - node: GenericType = strawberry.field(description="The item at the end of the edge.") - - cursor: str = strawberry.field(description="A cursor for use in pagination.") +class Query: + @connection(ListConnection[User]) + def get_users(self, name_starswith: str) -> Iterable[User]: + return some_function_that_returns_users(name_startswith=name_startswith) ``` -Now is a good time to think of what we could use as a cursor for our dataset. -Our cursor needs to be an opaque value, which doesn't usually change over time. -It makes sense to use base64 encoded IDs of users as our cursor, as they fit -both criteria. - - +This will generate a `Query` like this: -While working with Connections, it is a convention to base64-encode cursors. It -provides a unified interface to the end user. API clients need not bother about -the type of data to paginate, and can pass unique IDs during pagination. It also -makes the cursors opaque. - - +```graphql +type Query { + getUsers( + nameStartswith: String! + first: Int = null + last: Int = null + before: String = null + after: String = null + ): UserConnection! +} +``` -Let us define a couple of helper functions to encode and decode cursors as -follows: +## Converting the node to its proper type when resolving the connection -```py line=3,35-43 -# example.py +The connection expects that the resolver will return a list of objects that is a +subclass of its `NodeType`. But there may be situations where you are resolving +something that needs to be converted to the proper type, like an ORM model. -from base64 import b64encode, b64decode -from typing import Generic, TypeVar +In this case you can subclass the `Connection`/`ListConnection` and provide a +custom `resolve_node` method to it, which by default returns the node as is. For +example: +```python import strawberry +from strawberry.pagination import ListConnection, connection -user_data = [ - { - "id": 1, - "name": "Norman Osborn", - "occupation": "Founder, Oscorp Industries", - "age": 42, - }, - { - "id": 2, - "name": "Peter Parker", - "occupation": "Freelance Photographer, The Daily Bugle", - "age": 20, - }, - { - "id": 3, - "name": "Harold Osborn", - "occupation": "President, Oscorp Industries", - "age": 19, - }, - { - "id": 4, - "name": "Eddie Brock", - "occupation": "Journalist, The Eddie Brock Report", - "age": 20, - }, -] - - -def encode_user_cursor(id: int) -> str: - """ - Encodes the given user ID into a cursor. - - :param id: The user ID to encode. - - :return: The encoded cursor. - """ - return b64encode(f"user:{id}".encode("ascii")).decode("ascii") - - -def decode_user_cursor(cursor: str) -> int: - """ - Decodes the user ID from the given cursor. - - :param cursor: The cursor to decode. - - :return: The decoded user ID. - """ - cursor_data = b64decode(cursor.encode("ascii")).decode("ascii") - return int(cursor_data.split(":")[1]) - - -GenericType = TypeVar("GenericType") +from db.models import UserModel @strawberry.type -class Connection(Generic[GenericType]): - page_info: "PageInfo" = strawberry.field( - description="Information to aid in pagination." - ) - - edges: list["Edge[GenericType]"] = strawberry.field( - description="A list of edges in this connection." - ) +class User: + id: int + name: str @strawberry.type -class PageInfo: - has_next_page: bool = strawberry.field( - description="When paginating forwards, are there more items?" - ) - - has_previous_page: bool = strawberry.field( - description="When paginating backwards, are there more items?" - ) - - start_cursor: Optional[str] = strawberry.field( - description="When paginating backwards, the cursor to continue." - ) - - end_cursor: Optional[str] = strawberry.field( - description="When paginating forwards, the cursor to continue." - ) +class UserConnection(ListConnection[User]): + @classmethod + def resolve_node(cls, node: UserModel, *, info, **kwargs) -> User: + return User( + id=node.id, + name=node.name, + ) @strawberry.type -class Edge(Generic[GenericType]): - node: GenericType = strawberry.field(description="The item at the end of the edge.") - - cursor: str = strawberry.field(description="A cursor for use in pagination.") +class Query: + @connection(UserConnection) + def get_users(self, info: strawberry.Info) -> Iterable[UserDB]: + return UserDB.objects.all() ``` -Let us define a `get_users` field which returns a connection of users, as well -as an `UserType`. Let us also plug our query into a schema. +The main advantage of this approach instead of converting it inside the custom +resolver is that the `Connection` will paginate the `QuerySet` first, which in +case of Django will make sure that only the paginated results are fetched from +the database. After that, the `resolve_node` function will be called for each +result to retrieve the correct object for it. + +We used Django for this example, but the same applies to any other other similar +use case, like SQLAlchemy, etc. -```python line=104-174 -# example.py +## Full working example -from base64 import b64encode, b64decode -from typing import List, Optional, Generic, TypeVar +Here is a full working example of a connection of users which you can play with: +```python +from typing import Iterable import strawberry +from strawberry.pagination import ListConnection, connection user_data = [ { @@ -445,152 +367,28 @@ user_data = [ ] -def encode_user_cursor(id: int) -> str: - """ - Encodes the given user ID into a cursor. - - :param id: The user ID to encode. - - :return: The encoded cursor. - """ - return b64encode(f"user:{id}".encode("ascii")).decode("ascii") - - -def decode_user_cursor(cursor: str) -> int: - """ - Decodes the user ID from the given cursor. - - :param cursor: The cursor to decode. - - :return: The decoded user ID. - """ - cursor_data = b64decode(cursor.encode("ascii")).decode("ascii") - return int(cursor_data.split(":")[1]) - - -GenericType = TypeVar("GenericType") - - -@strawberry.type -class Connection(Generic[GenericType]): - page_info: "PageInfo" = strawberry.field( - description="Information to aid in pagination." - ) - - edges: list["Edge[GenericType]"] = strawberry.field( - description="A list of edges in this connection." - ) - - -@strawberry.type -class PageInfo: - has_next_page: bool = strawberry.field( - description="When paginating forwards, are there more items?" - ) - - has_previous_page: bool = strawberry.field( - description="When paginating backwards, are there more items?" - ) - - start_cursor: Optional[str] = strawberry.field( - description="When paginating backwards, the cursor to continue." - ) - - end_cursor: Optional[str] = strawberry.field( - description="When paginating forwards, the cursor to continue." - ) - - -@strawberry.type -class Edge(Generic[GenericType]): - node: GenericType = strawberry.field(description="The item at the end of the edge.") - - cursor: str = strawberry.field(description="A cursor for use in pagination.") - - @strawberry.type class User: - id: int = strawberry.field(description="The id of the user.") - - name: str = strawberry.field(description="The name of the user.") - - occupation: str = strawberry.field(description="The occupation of the user.") - - age: int = strawberry.field(description="The age of the user.") + id: int + name: str + occupation: str + age: int @strawberry.type class Query: - @strawberry.field(description="Get a list of users.") - def get_users( - self, first: int = 2, after: Optional[str] = None - ) -> Connection[User]: - if after is not None: - # decode the user ID from the given cursor. - user_id = decode_user_cursor(cursor=after) - else: - # no cursor was given (this happens usually when the - # client sends a query for the first time). - user_id = 0 - - # filter the user data, going through the next set of results. - filtered_data = list(filter(lambda user: user["id"] > user_id, user_data)) - - # slice the relevant user data (Here, we also slice an - # additional user instance, to prepare the next cursor). - sliced_users = filtered_data[: first + 1] - - if len(sliced_users) > first: - # calculate the client's next cursor. - last_user = sliced_users.pop(-1) - next_cursor = encode_user_cursor(id=last_user["id"]) - has_next_page = True - else: - # We have reached the last page, and - # don't have the next cursor. - next_cursor = None - has_next_page = False - - # We know that we have items in the - # previous page window if the initial user ID - # was not the first one. - has_previous_page = user_id > 0 - - # build user edges. - edges = [ - Edge( - node=User(**user), - cursor=encode_user_cursor(id=user["id"]), + @connection(ListConnection[User]) + def get_users(self) -> Iterable[User]: + return [ + User( + id=user["id"], + name=user["name"], + occupation=user["occupation"], + age=user["age"], ) - for user in sliced_users + for user in user_data ] - if edges: - # we have atleast one edge. Get the cursor - # of the first edge we have. - start_cursor = edges[0].cursor - else: - # We have no edges to work with. - start_cursor = None - - if len(edges) > 1: - # We have atleast 2 edges. Get the cursor - # of the last edge we have. - end_cursor = edges[-1].cursor - else: - # We don't have enough edges to work with. - end_cursor = None - - return Connection( - edges=edges, - page_info=PageInfo( - has_next_page=has_next_page, - has_previous_page=has_previous_page, - start_cursor=start_cursor, - end_cursor=end_cursor, - ), - ) - schema = strawberry.Schema(query=Query) ``` diff --git a/docs/guides/pagination/overview.md b/docs/guides/pagination/overview.md index 8b309575f8..9938436415 100644 --- a/docs/guides/pagination/overview.md +++ b/docs/guides/pagination/overview.md @@ -136,8 +136,8 @@ consistent way to handle pagination. Strawberry provides a cursor based pagination implementing the -[relay spec](https://relay.dev/docs/guides/graphql-server-specification/). You -can read more about it in the [relay](../relay) page. +[connection](https://relay.dev/docs/guides/graphql-server-specification/). You +can read more about it in the [connections](./connections) page. @@ -203,4 +203,4 @@ Let us look at how we can implement pagination in GraphQL. - [Implementing Offset Pagination](./offset-based.md) - [Implementing Cursor Pagination](./cursor-based.md) -- [Implementing the Relay Connection Specification](./connections.md) +- [Implementing the Connection Specification](./connections.md) diff --git a/docs/guides/relay.md b/docs/guides/relay.md index 394583abcf..2b78cea8da 100644 --- a/docs/guides/relay.md +++ b/docs/guides/relay.md @@ -82,14 +82,15 @@ string `:`. In the example above, the `Fruit` with a code of -Now we can expose it in the schema for retrieval and pagination like: +Now we can expose it in the schema for retrieval and +[pagination](./pagination/connections) like: ```python @strawberry.type class Query: node: relay.Node = relay.node() - @relay.connection(relay.ListConnection[Fruit]) + @connection(ListConnection[Fruit]) def fruits(self) -> Iterable[Fruit]: # This can be a database query, a generator, an async generator, etc return all_fruits.values() @@ -179,16 +180,6 @@ query { } ``` -The connection resolver for `relay.ListConnection` should return one of those: - -- `List[]` -- `Iterator[]` -- `Iterable[]` -- `AsyncIterator[]` -- `AsyncIterable[]` -- `Generator[, Any, Any]` -- `AsyncGenerator[, Any]` - ### The node field As demonstrated above, the `Node` field can be used to retrieve/refetch any @@ -205,207 +196,14 @@ It can be defined in the `Query` objects in 4 ways: - `node: List[Optional[Node]]`: The same as `List[Node]`, but the returned list can contain `null` values if the given objects don't exist. -### Custom connection pagination - -The default `relay.Connection` class don't implement any pagination logic, and -should be used as a base class to implement your own pagination logic. All you -need to do is implement the `resolve_connection` classmethod. - -The integration provides `relay.ListConnection`, which implements a limit/offset -approach to paginate the results. This is a basic approach and might be enough -for most use cases. - - - -`relay.ListConnection` implementes the limit/offset by using slices. That means -that you can override what the slice does by customizing the `__getitem__` -method of the object returned by your nodes resolver. - -For example, when working with `Django`, `resolve_nodes` can return a -`QuerySet`, meaning that the slice on it will translate to a `LIMIT`/`OFFSET` in -the SQL query, which will fetch only the data that is needed from the database. - -Also note that if that object doesn't have a `__getitem__` attribute, it will -use `itertools.islice` to paginate it, meaning that when a generator is being -resolved it will only generate as much results as needed for the given -pagination, the worst case scenario being the last results needing to be -returned. - - - -Now, suppose we want to implement a custom cursor-based pagination for our -previous example. We can do something like this: - -```python -import strawberry -from strawberry import relay - - -@strawberry.type -class FruitCustomPaginationConnection(relay.Connection[Fruit]): - @classmethod - def resolve_connection( - cls, - nodes: Iterable[Fruit], - *, - info: Optional[Info] = None, - before: Optional[str] = None, - after: Optional[str] = None, - first: Optional[int] = None, - last: Optional[int] = None, - ): - # NOTE: This is a showcase implementation and is far from - # being optimal performance wise - edges_mapping = { - relay.to_base64("fruit_name", n.name): relay.Edge( - node=n, - cursor=relay.to_base64("fruit_name", n.name), - ) - for n in sorted(nodes, key=lambda f: f.name) - } - edges = list(edges_mapping.values()) - first_edge = edges[0] if edges else None - last_edge = edges[-1] if edges else None - - if after is not None: - after_edge_idx = edges.index(edges_mapping[after]) - edges = [e for e in edges if edges.index(e) > after_edge_idx] - - if before is not None: - before_edge_idx = edges.index(edges_mapping[before]) - edges = [e for e in edges if edges.index(e) < before_edge_idx] - - if first is not None: - edges = edges[:first] - - if last is not None: - edges = edges[-last:] - - return cls( - edges=edges, - page_info=strawberry.relay.PageInfo( - start_cursor=edges[0].cursor if edges else None, - end_cursor=edges[-1].cursor if edges else None, - has_previous_page=( - first_edge is not None and bool(edges) and edges[0] != first_edge - ), - has_next_page=( - last_edge is not None and bool(edges) and edges[-1] != last_edge - ), - ), - ) - - -@strawberry.type -class Query: - @relay.connection(FruitCustomPaginationConnection) - def fruits(self) -> Iterable[Fruit]: - # This can be a database query, a generator, an async generator, etc - return all_fruits.values() -``` - -In the example above we specialized the `FruitCustomPaginationConnection` by -inheriting it from `relay.Connection[Fruit]`. We could still keep it generic by -inheriting it from `relay.Connection[relay.NodeType]` and then specialize it -when defining the field, making it possible to use our custom pagination logic -with more than one type. - -### Custom connection arguments - -By default the connection will automatically insert some arguments for it to be -able to paginate the results. Those are: - -- `before`: Returns the items in the list that come before the specified cursor -- `after`: Returns the items in the list that come after the " "specified cursor -- `first`: Returns the first n items from the list -- `last`: Returns the items in the list that come after the " "specified cursor - -You can still define extra arguments to be used by your own resolver or custom -pagination logic. For example, suppose we want to return the pagination of all -fruits whose name starts with a given string. We could do that like this: - -```python -@strawberry.type -class Query: - @relay.connection(relay.ListConnection[Fruit]) - def fruits_with_filter( - self, - info: strawberry.Info, - name_endswith: str, - ) -> Iterable[Fruit]: - for f in fruits.values(): - if f.name.endswith(name_endswith): - yield f -``` - -This will generate a schema like this: - -```graphql -type Query { - fruitsWithFilter( - nameEndswith: String! - before: String = null - after: String = null - first: Int = null - last: Int = null - ): FruitConnection! -} -``` - -### Convert the node to its proper type when resolving the connection - -The connection expects that the resolver will return a list of objects that is a -subclass of its `NodeType`. But there may be situations where you are resolving -something that needs to be converted to the proper type, like an ORM model. - -In this case you can subclass the `relay.Connection`/`relay.ListConnection` and -provide a custom `resolve_node` method to it, which by default returns the node -as is. For example: - -```python -import strawberry -from strawberry import relay - -from db import models - - -@strawberry.type -class Fruit(relay.Node): - code: relay.NodeID[int] - name: str - weight: float - - -@strawberry.type -class FruitDBConnection(relay.ListConnection[Fruit]): - @classmethod - def resolve_node(cls, node: FruitDB, *, info: strawberry.Info, **kwargs) -> Fruit: - return Fruit( - code=node.code, - name=node.name, - weight=node.weight, - ) - - -@strawberry.type -class Query: - @relay.connection(FruitDBConnection) - def fruits_with_filter( - self, - info: strawberry.Info, - name_endswith: str, - ) -> Iterable[models.Fruit]: - return models.Fruit.objects.filter(name__endswith=name_endswith) -``` +### The connection field -The main advantage of this approach instead of converting it inside the custom -resolver is that the `Connection` will paginate the `QuerySet` first, which in -case of django will make sure that only the paginated results are fetched from -the database. After that, the `resolve_node` function will be called for each -result to retrieve the correct object for it. +The connection field, although defined in the relay specifications, is actually +a well-known pattern in GraphQL for getting paginated results. -We used django for this example, but the same applies to any other other similar -use case, like SQLAlchemy, etc. +Because of that, it is implemented in Strawberry as a generic pagination +solution, which you can read more about in the +[pagination](./pagination/connections) page. ### The GlobalID scalar diff --git a/strawberry/pagination/__init__.py b/strawberry/pagination/__init__.py new file mode 100644 index 0000000000..cf7b142207 --- /dev/null +++ b/strawberry/pagination/__init__.py @@ -0,0 +1,21 @@ +from .fields import ConnectionExtension, connection +from .types import ( + Connection, + Edge, + ListConnection, + NodeType, + PageInfo, +) +from .utils import from_base64, to_base64 + +__all__ = [ + "Connection", + "ConnectionExtension", + "Edge", + "ListConnection", + "NodeType", + "PageInfo", + "connection", + "from_base64", + "to_base64", +] diff --git a/strawberry/pagination/exceptions.py b/strawberry/pagination/exceptions.py new file mode 100644 index 0000000000..6e9c0bcc74 --- /dev/null +++ b/strawberry/pagination/exceptions.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from collections.abc import Callable +from functools import cached_property +from typing import TYPE_CHECKING, Optional, Type, cast + +from strawberry.exceptions.exception import StrawberryException +from strawberry.exceptions.utils.source_finder import SourceFinder + +if TYPE_CHECKING: + from strawberry.exceptions.exception_source import ExceptionSource + from strawberry.types.fields.resolver import StrawberryResolver + + +class ConnectionWrongAnnotationError(StrawberryException): + def __init__(self, field_name: str, cls: Type) -> None: + self.cls = cls + self.field_name = field_name + + self.message = ( + f'Wrong annotation used on field "{field_name}". It should be ' + 'annotated with a "Connection" subclass.' + ) + self.rich_message = ( + f"Wrong annotation for field `[underline]{self.field_name}[/]`" + ) + self.suggestion = ( + "To fix this error you can add a valid annotation, " + f"like [italic]`{self.field_name}: Connection[{cls}]` " + f"or [italic]`@connection(Connection[{cls}])`" + ) + self.annotation_message = "connection wrong annotation" + + super().__init__(self.message) + + @cached_property + def exception_source(self) -> Optional[ExceptionSource]: + if self.cls is None: + return None # pragma: no cover + + source_finder = SourceFinder() + return source_finder.find_class_attribute_from_object(self.cls, self.field_name) + + +class ConnectionWrongResolverAnnotationError(StrawberryException): + def __init__(self, field_name: str, resolver: StrawberryResolver) -> None: + self.function = resolver.wrapped_func + self.field_name = field_name + + self.message = ( + f'Wrong annotation used on "{field_name}" resolver. ' + "It should be return an iterable or async iterable object." + ) + self.rich_message = ( + f"Wrong annotation used on `{field_name}` resolver. " + "It should be return an `iterable` or `async iterable` object." + ) + self.suggestion = ( + "To fix this error you can annootate your resolver to return " + "one of the following options: `List[]`, " + "`Iterator[]`, `Iterable[]`, " + "`AsyncIterator[]`, `AsyncIterable[]`, " + "`Generator[, Any, Any]` and " + "`AsyncGenerator[, Any]`." + ) + self.annotation_message = "connection wrong resolver annotation" + + super().__init__(self.message) + + @cached_property + def exception_source(self) -> Optional[ExceptionSource]: + if self.function is None: + return None # pragma: no cover + + source_finder = SourceFinder() + return source_finder.find_function_from_object(cast(Callable, self.function)) + + +__all__ = [ + "ConnectionWrongAnnotationError", + "ConnectionWrongResolverAnnotationError", +] diff --git a/strawberry/pagination/fields.py b/strawberry/pagination/fields.py new file mode 100644 index 0000000000..3e58c6a0a3 --- /dev/null +++ b/strawberry/pagination/fields.py @@ -0,0 +1,338 @@ +from __future__ import annotations + +import dataclasses +import inspect +from collections.abc import AsyncIterable +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Callable, + ForwardRef, + Iterable, + Iterator, + List, + Mapping, + Optional, + Sequence, + Type, + Union, + cast, + overload, +) +from typing_extensions import get_origin + +from strawberry.annotation import StrawberryAnnotation +from strawberry.extensions.field_extension import ( + AsyncExtensionResolver, + FieldExtension, + SyncExtensionResolver, +) +from strawberry.types.arguments import StrawberryArgument +from strawberry.types.field import _RESOLVER_TYPE, StrawberryField +from strawberry.types.lazy_type import LazyType +from strawberry.utils.typing import eval_type, is_generic_alias + +from .exceptions import ( + ConnectionWrongAnnotationError, + ConnectionWrongResolverAnnotationError, +) +from .types import Connection, NodeIterableType, NodeType + +if TYPE_CHECKING: + from typing_extensions import Literal + + from strawberry.permission import BasePermission + from strawberry.types.info import Info + + +class ConnectionExtension(FieldExtension): + connection_type: Type[Connection] + + def apply(self, field: StrawberryField) -> None: + field.arguments = [ + *field.arguments, + StrawberryArgument( + python_name="before", + graphql_name=None, + type_annotation=StrawberryAnnotation(Optional[str]), + description=( + "Returns the items in the list that come before the " + "specified cursor." + ), + default=None, + ), + StrawberryArgument( + python_name="after", + graphql_name=None, + type_annotation=StrawberryAnnotation(Optional[str]), + description=( + "Returns the items in the list that come after the " + "specified cursor." + ), + default=None, + ), + StrawberryArgument( + python_name="first", + graphql_name=None, + type_annotation=StrawberryAnnotation(Optional[int]), + description="Returns the first n items from the list.", + default=None, + ), + StrawberryArgument( + python_name="last", + graphql_name=None, + type_annotation=StrawberryAnnotation(Optional[int]), + description=( + "Returns the items in the list that come after the " + "specified cursor." + ), + default=None, + ), + ] + + f_type = field.type + + if isinstance(f_type, LazyType): + f_type = f_type.resolve_type() + field.type = f_type + + type_origin = get_origin(f_type) if is_generic_alias(f_type) else f_type + if not isinstance(type_origin, type) or not issubclass(type_origin, Connection): + raise ConnectionWrongAnnotationError(field.name, cast(type, field.origin)) + + assert field.base_resolver + # TODO: We are not using resolver_type.type because it will call + # StrawberryAnnotation.resolve, which will strip async types from the + # type (i.e. AsyncGenerator[Fruit] will become Fruit). This is done there + # for subscription support, but we can't use it here. Maybe we can refactor + # this in the future. + resolver_type = field.base_resolver.signature.return_annotation + if isinstance(resolver_type, str): + resolver_type = ForwardRef(resolver_type) + if isinstance(resolver_type, ForwardRef): + resolver_type = eval_type( + resolver_type, + field.base_resolver._namespace, + None, + ) + + origin = get_origin(resolver_type) + if origin is None or not issubclass( + origin, (Iterator, Iterable, AsyncIterator, AsyncIterable) + ): + raise ConnectionWrongResolverAnnotationError( + field.name, field.base_resolver + ) + + self.connection_type = cast(Type[Connection], field.type) + + def resolve( + self, + next_: SyncExtensionResolver, + source: Any, + info: Info, + *, + before: Optional[str] = None, + after: Optional[str] = None, + first: Optional[int] = None, + last: Optional[int] = None, + **kwargs: Any, + ) -> Any: + assert self.connection_type is not None + return self.connection_type.resolve_connection( + cast(Iterable, next_(source, info, **kwargs)), + info=info, + before=before, + after=after, + first=first, + last=last, + ) + + async def resolve_async( + self, + next_: AsyncExtensionResolver, + source: Any, + info: Info, + *, + before: Optional[str] = None, + after: Optional[str] = None, + first: Optional[int] = None, + last: Optional[int] = None, + **kwargs: Any, + ) -> Any: + assert self.connection_type is not None + nodes = next_(source, info, **kwargs) + # nodes might be an AsyncIterable/AsyncIterator + # In this case we don't await for it + if inspect.isawaitable(nodes): + nodes = await nodes + + resolved = self.connection_type.resolve_connection( + cast(Iterable, nodes), + info=info, + before=before, + after=after, + first=first, + last=last, + ) + + # If nodes was an AsyncIterable/AsyncIterator, resolve_connection + # will return a coroutine which we need to await + if inspect.isawaitable(resolved): + resolved = await resolved + return resolved + + +@overload +def connection( + graphql_type: Optional[Type[Connection[NodeType]]] = None, + *, + resolver: Optional[_RESOLVER_TYPE[NodeIterableType[Any]]] = None, + name: Optional[str] = None, + is_subscription: bool = False, + description: Optional[str] = None, + init: Literal[True] = True, + permission_classes: Optional[List[Type[BasePermission]]] = None, + deprecation_reason: Optional[str] = None, + default: Any = dataclasses.MISSING, + default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, + metadata: Optional[Mapping[Any, Any]] = None, + directives: Optional[Sequence[object]] = (), + extensions: List[FieldExtension] = (), # type: ignore +) -> Any: ... + + +@overload +def connection( + graphql_type: Optional[Type[Connection[NodeType]]] = None, + *, + name: Optional[str] = None, + is_subscription: bool = False, + description: Optional[str] = None, + permission_classes: Optional[List[Type[BasePermission]]] = None, + deprecation_reason: Optional[str] = None, + default: Any = dataclasses.MISSING, + default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, + metadata: Optional[Mapping[Any, Any]] = None, + directives: Optional[Sequence[object]] = (), + extensions: List[FieldExtension] = (), # type: ignore +) -> StrawberryField: ... + + +def connection( + graphql_type: Optional[Type[Connection[NodeType]]] = None, + *, + resolver: Optional[_RESOLVER_TYPE[Any]] = None, + name: Optional[str] = None, + is_subscription: bool = False, + description: Optional[str] = None, + permission_classes: Optional[List[Type[BasePermission]]] = None, + deprecation_reason: Optional[str] = None, + default: Any = dataclasses.MISSING, + default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, + metadata: Optional[Mapping[Any, Any]] = None, + directives: Optional[Sequence[object]] = (), + extensions: List[FieldExtension] = (), # type: ignore + # This init parameter is used by pyright to determine whether this field + # is added in the constructor or not. It is not used to change + # any behavior at the moment. + init: Literal[True, False, None] = None, +) -> Any: + """Annotate a property or a method to create a connection field. + + Connections are mostly used for pagination purposes. This decorator + helps creating a complete connection endpoint that provides default arguments + and has a default implementation for the connection slicing. + + Note that when setting a resolver to this field, it is expected for this + resolver to return an iterable of the expected node type, not the connection + itself. That iterable will then be paginated accordingly. So, the main use + case for this is to provide a filtered iterable of nodes by using some custom + filter arguments. + + Args: + graphql_type: The type of the nodes in the connection. This is used to + determine the type of the edges and the node field in the connection. + resolver: The resolver for the connection. This is expected to return an + iterable of the expected node type. + name: The GraphQL name of the field. + is_subscription: Whether the field is a subscription. + description: The GraphQL description of the field. + permission_classes: The permission classes to apply to the field. + deprecation_reason: The deprecation reason of the field. + default: The default value of the field. + default_factory: The default factory of the field. + metadata: The metadata of the field. + directives: The directives to apply to the field. + extensions: The extensions to apply to the field. + init: Used only for type checking purposes. + + Examples: + Annotating something like this: + + ```python + @strawberry.type + class X: + some_node: Connection[SomeType] = connection( + resolver=get_some_nodes, + description="ABC", + ) + + @connection(Connection[SomeType], description="ABC") + def get_some_nodes(self, age: int) -> Iterable[SomeType]: ... + ``` + + Will produce a query like this: + + ```graphql + query { + someNode ( + before: String + after: String + first: String + after: String + age: Int + ) { + totalCount + pageInfo { + hasNextPage + hasPreviousPage + startCursor + endCursor + } + edges { + cursor + node { + id + ... + } + } + } + } + ``` + + .. _Connections: + https://relay.dev/graphql/connections.htm + + """ + f = StrawberryField( + python_name=None, + graphql_name=name, + description=description, + type_annotation=StrawberryAnnotation.from_annotation(graphql_type), + is_subscription=is_subscription, + permission_classes=permission_classes or [], + deprecation_reason=deprecation_reason, + default=default, + default_factory=default_factory, + metadata=metadata, + directives=directives or (), + extensions=[*extensions, ConnectionExtension()], + ) + if resolver is not None: + f = f(resolver) + return f + + +__all__ = ["ConnectionExtension", "connection"] diff --git a/strawberry/pagination/types.py b/strawberry/pagination/types.py new file mode 100644 index 0000000000..3e8c4fec8c --- /dev/null +++ b/strawberry/pagination/types.py @@ -0,0 +1,376 @@ +from __future__ import annotations + +import itertools +import sys +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterable, + AsyncIterator, + Generic, + Iterable, + Iterator, + List, + Optional, + Sequence, + TypeVar, + Union, + cast, +) +from typing_extensions import Self, TypeAlias + +from strawberry.types.base import ( + StrawberryContainer, + get_object_definition, +) +from strawberry.types.field import field +from strawberry.types.info import Info # noqa: TCH001 +from strawberry.types.object_type import type +from strawberry.utils.aio import aenumerate, aislice +from strawberry.utils.inspect import in_async_context + +from .utils import ( + SliceMetadata, + should_resolve_list_connection_edges, + to_base64, +) + +if TYPE_CHECKING: + from strawberry.utils.await_maybe import AwaitableOrValue + +NodeType = TypeVar("NodeType") +NodeIterableType: TypeAlias = Union[ + Iterator[NodeType], + Iterable[NodeType], + AsyncIterator[NodeType], + AsyncIterable[NodeType], +] + +PREFIX = "arrayconnection" + + +@type(description="Information to aid in pagination.") +class PageInfo: + """Information to aid in pagination. + + Attributes: + has_next_page: + When paginating forwards, are there more items? + has_previous_page: + When paginating backwards, are there more items? + start_cursor: + When paginating backwards, the cursor to continue + end_cursor: + When paginating forwards, the cursor to continue + """ + + has_next_page: bool = field( + description="When paginating forwards, are there more items?", + ) + has_previous_page: bool = field( + description="When paginating backwards, are there more items?", + ) + start_cursor: Optional[str] = field( + description="When paginating backwards, the cursor to continue.", + ) + end_cursor: Optional[str] = field( + description="When paginating forwards, the cursor to continue.", + ) + + +@type(description="An edge in a connection.") +class Edge(Generic[NodeType]): + """An edge in a connection. + + Attributes: + cursor: + A cursor for use in pagination + node: + The item at the end of the edge + """ + + cursor: str = field(description="A cursor for use in pagination") + node: NodeType = field(description="The item at the end of the edge") + + @classmethod + def resolve_edge(cls, node: NodeType, *, cursor: Any = None) -> Self: + return cls(cursor=to_base64(PREFIX, cursor), node=node) + + +@type(description="A connection to a list of items.") +class Connection(Generic[NodeType]): + """A connection to a list of items. + + Attributes: + page_info: + Pagination data for this connection + edges: + Contains the nodes in this connection + + """ + + page_info: PageInfo = field(description="Pagination data for this connection") + edges: List[Edge[NodeType]] = field( + description="Contains the nodes in this connection" + ) + + @classmethod + def resolve_node(cls, node: Any, *, info: Info, **kwargs: Any) -> NodeType: + """The identity function for the node. + + This method is used to resolve a node of a different type to the + connection's `NodeType`. + + By default it returns the node itself, but subclasses can override + this to provide a custom implementation. + + Args: + node: + The resolved node which should return an instance of this + connection's `NodeType`. + info: + The strawberry execution info resolve the type name from. + **kwargs: + Additional arguments passed to the resolver. + + """ + return node + + @classmethod + def resolve_connection( + cls, + nodes: NodeIterableType[NodeType], + *, + info: Info, + before: Optional[str] = None, + after: Optional[str] = None, + first: Optional[int] = None, + last: Optional[int] = None, + **kwargs: Any, + ) -> AwaitableOrValue[Self]: + """Resolve a connection from nodes. + + Subclasses must define this method to paginate nodes based + on `first`/`last`/`before`/`after` arguments. + + Args: + info: The strawberry execution info resolve the type name from. + nodes: An iterable/iteretor of nodes to paginate. + before: Returns the items in the list that come before the specified cursor. + after: Returns the items in the list that come after the specified cursor. + first: Returns the first n items from the list. + last: Returns the items in the list that come after the specified cursor. + kwargs: Additional arguments passed to the resolver. + + Returns: + The resolved `Connection` + + """ + raise NotImplementedError + + +@type(name="Connection", description="A connection to a list of items.") +class ListConnection(Connection[NodeType]): + """A connection to a list of items. + + Attributes: + page_info: + Pagination data for this connection + edges: + Contains the nodes in this connection + + """ + + page_info: PageInfo = field(description="Pagination data for this connection") + edges: List[Edge[NodeType]] = field( + description="Contains the nodes in this connection" + ) + + @classmethod + def resolve_connection( + cls, + nodes: NodeIterableType[NodeType], + *, + info: Info, + before: Optional[str] = None, + after: Optional[str] = None, + first: Optional[int] = None, + last: Optional[int] = None, + **kwargs: Any, + ) -> AwaitableOrValue[Self]: + """Resolve a connection from the list of nodes. + + This uses the described Connection Pagination algorithm_ + + Args: + info: The strawberry execution info resolve the type name from. + nodes: An iterable/iteretor of nodes to paginate. + before: Returns the items in the list that come before the specified cursor. + after: Returns the items in the list that come after the specified cursor. + first: Returns the first n items from the list. + last: Returns the items in the list that come after the specified cursor. + kwargs: Additional arguments passed to the resolver. + + Returns: + The resolved `Connection` + + .. _Connection Pagination algorithm: + https://relay.dev/graphql/connections.htm#sec-Pagination-algorithm + """ + slice_metadata = SliceMetadata.from_arguments( + info, + before=before, + after=after, + first=first, + last=last, + ) + + type_def = get_object_definition(cls) + assert type_def + field_def = type_def.get_field("edges") + assert field_def + + field = field_def.resolve_type(type_definition=type_def) + while isinstance(field, StrawberryContainer): + field = field.of_type + + edge_class = cast(Edge[NodeType], field) + + if isinstance(nodes, (AsyncIterator, AsyncIterable)) and in_async_context(): + + async def resolver() -> Self: + try: + iterator = cast( + Union[AsyncIterator[NodeType], AsyncIterable[NodeType]], + cast(Sequence, nodes)[ + slice_metadata.start : slice_metadata.overfetch + ], + ) + except TypeError: + # TODO: Why mypy isn't narrowing this based on the if above? + assert isinstance(nodes, (AsyncIterator, AsyncIterable)) + iterator = aislice( + nodes, + slice_metadata.start, + slice_metadata.overfetch, + ) + + # The slice above might return an object that now is not async + # iterable anymore (e.g. an already cached django queryset) + if isinstance(iterator, (AsyncIterator, AsyncIterable)): + edges: List[Edge] = [ + edge_class.resolve_edge( + cls.resolve_node(v, info=info, **kwargs), + cursor=slice_metadata.start + i, + ) + async for i, v in aenumerate(iterator) + ] + else: + edges: List[Edge] = [ # type: ignore[no-redef] + edge_class.resolve_edge( + cls.resolve_node(v, info=info, **kwargs), + cursor=slice_metadata.start + i, + ) + for i, v in enumerate(iterator) + ] + + has_previous_page = slice_metadata.start > 0 + if ( + slice_metadata.expected is not None + and len(edges) == slice_metadata.expected + 1 + ): + # Remove the overfetched result + edges = edges[:-1] + has_next_page = True + elif slice_metadata.end == sys.maxsize: + # Last was asked without any after/before + assert last is not None + original_len = len(edges) + edges = edges[-last:] + has_next_page = False + has_previous_page = len(edges) != original_len + else: + has_next_page = False + + return cls( + edges=edges, + page_info=PageInfo( + start_cursor=edges[0].cursor if edges else None, + end_cursor=edges[-1].cursor if edges else None, + has_previous_page=has_previous_page, + has_next_page=has_next_page, + ), + ) + + return resolver() + + try: + iterator = cast( + Union[Iterator[NodeType], Iterable[NodeType]], + cast(Sequence, nodes)[slice_metadata.start : slice_metadata.overfetch], + ) + except TypeError: + assert isinstance(nodes, (Iterable, Iterator)) + iterator = itertools.islice( + nodes, + slice_metadata.start, + slice_metadata.overfetch, + ) + + if not should_resolve_list_connection_edges(info): + return cls( + edges=[], + page_info=PageInfo( + start_cursor=None, + end_cursor=None, + has_previous_page=False, + has_next_page=False, + ), + ) + + edges = [ + edge_class.resolve_edge( + cls.resolve_node(v, info=info, **kwargs), + cursor=slice_metadata.start + i, + ) + for i, v in enumerate(iterator) + ] + + has_previous_page = slice_metadata.start > 0 + if ( + slice_metadata.expected is not None + and len(edges) == slice_metadata.expected + 1 + ): + # Remove the overfetched result + edges = edges[:-1] + has_next_page = True + elif slice_metadata.end == sys.maxsize: + # Last was asked without any after/before + assert last is not None + original_len = len(edges) + edges = edges[-last:] + has_next_page = False + has_previous_page = len(edges) != original_len + else: + has_next_page = False + + return cls( + edges=edges, + page_info=PageInfo( + start_cursor=edges[0].cursor if edges else None, + end_cursor=edges[-1].cursor if edges else None, + has_previous_page=has_previous_page, + has_next_page=has_next_page, + ), + ) + + +__all__ = [ + "NodeIterableType", + "NodeType", + "PREFIX", + "Connection", + "Edge", + "PageInfo", + "ListConnection", +] diff --git a/strawberry/pagination/utils.py b/strawberry/pagination/utils.py new file mode 100644 index 0000000000..34af41f8ca --- /dev/null +++ b/strawberry/pagination/utils.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +import base64 +import dataclasses +import sys +from typing import TYPE_CHECKING, Any, Tuple, Union +from typing_extensions import Self, assert_never + +from strawberry.types.base import StrawberryObjectDefinition +from strawberry.types.nodes import InlineFragment, Selection + +if TYPE_CHECKING: + from strawberry.types.info import Info + + +def from_base64(value: str) -> Tuple[str, str]: + """Parse the base64 encoded relay value. + + Args: + value: + The value to be parsed + + Returns: + A tuple of (TypeName, NodeID). + + Raises: + ValueError: + If the value is not in the expected format + + """ + try: + res = base64.b64decode(value.encode()).decode().split(":", 1) + except Exception as e: + raise ValueError(str(e)) from e + + if len(res) != 2: + raise ValueError(f"{res} expected to contain only 2 items") + + return res[0], res[1] + + +def to_base64(type_: Union[str, type, StrawberryObjectDefinition], node_id: Any) -> str: + """Encode the type name and node id to a base64 string. + + Args: + type_: + The GraphQL type, type definition or type name. + node_id: + The node id itself + + Returns: + A GlobalID, which is a string resulting from base64 encoding :. + + Raises: + ValueError: + If the value is not a valid GraphQL type or name + + """ + try: + if isinstance(type_, str): + type_name = type_ + elif isinstance(type_, StrawberryObjectDefinition): + type_name = type_.name + elif isinstance(type_, type): + type_name = type_.__strawberry_definition__.name # type:ignore + else: # pragma: no cover + assert_never(type_) + except Exception as e: + raise ValueError(f"{type_} is not a valid GraphQL type or name") from e + + return base64.b64encode(f"{type_name}:{node_id}".encode()).decode() + + +def should_resolve_list_connection_edges(info: Info) -> bool: + """Check if the user requested to resolve the `edges` field of a connection. + + Args: + info: + The strawberry execution info resolve the type name from + + Returns: + True if the user requested to resolve the `edges` field of a connection, False otherwise. + + """ + resolve_for_field_names = {"edges", "pageInfo"} + + def _check_selection(selection: Selection) -> bool: + """Recursively inspect the selection to check if the user requested to resolve the `edges` field. + + Args: + selection (Selection): The selection to check. + + Returns: + bool: True if the user requested to resolve the `edges` field of a connection, False otherwise. + """ + if ( + not isinstance(selection, InlineFragment) + and selection.name in resolve_for_field_names + ): + return True + if selection.selections: + return any( + _check_selection(selection) for selection in selection.selections + ) + return False + + for selection_field in info.selected_fields: + for selection in selection_field.selections: + if _check_selection(selection): + return True + return False + + +@dataclasses.dataclass +class SliceMetadata: + start: int + end: int + expected: int | None + + @property + def overfetch(self) -> int: + # Overfetch by 1 to check if we have a next result + return self.end + 1 if self.end != sys.maxsize else self.end + + @classmethod + def from_arguments( + cls, + info: Info, + *, + before: str | None = None, + after: str | None = None, + first: int | None = None, + last: int | None = None, + ) -> Self: + """Get the slice metadata to use on ListConnection.""" + from strawberry.pagination.types import PREFIX + + max_results = info.schema.config.relay_max_results + start = 0 + end: int | None = None + + if after: + after_type, after_parsed = from_base64(after) + if after_type != PREFIX: + raise TypeError("Argument 'after' contains a non-existing value.") + + start = int(after_parsed) + 1 + if before: + before_type, before_parsed = from_base64(before) + if before_type != PREFIX: + raise TypeError("Argument 'before' contains a non-existing value.") + end = int(before_parsed) + + if isinstance(first, int): + if first < 0: + raise ValueError("Argument 'first' must be a non-negative integer.") + + if first > max_results: + raise ValueError( + f"Argument 'first' cannot be higher than {max_results}." + ) + + if end is not None: + start = max(0, end - 1) + + end = start + first + if isinstance(last, int): + if last < 0: + raise ValueError("Argument 'last' must be a non-negative integer.") + + if last > max_results: + raise ValueError( + f"Argument 'last' cannot be higher than {max_results}." + ) + + if end is not None: + start = max(start, end - last) + else: + end = sys.maxsize + + if end is None: + end = start + max_results + + expected = end - start if end != sys.maxsize else None + + return cls( + start=start, + end=end, + expected=expected, + ) + + +__all__ = [ + "from_base64", + "to_base64", + "should_resolve_list_connection_edges", + "SliceMetadata", +] diff --git a/strawberry/relay/__init__.py b/strawberry/relay/__init__.py index c41e777b76..7ed73687b9 100644 --- a/strawberry/relay/__init__.py +++ b/strawberry/relay/__init__.py @@ -1,31 +1,54 @@ -from .fields import ConnectionExtension, NodeExtension, connection, node -from .types import ( +import warnings +from typing import Any + +from strawberry.pagination.fields import ConnectionExtension, connection +from strawberry.pagination.types import ( Connection, Edge, + ListConnection, + NodeType, + PageInfo, +) +from strawberry.pagination.utils import from_base64, to_base64 + +from .fields import NodeExtension, node +from .types import ( GlobalID, GlobalIDValueError, - ListConnection, Node, NodeID, - NodeType, - PageInfo, ) -from .utils import from_base64, to_base64 + +_DEPRECATIONS = { + "Connection": Connection, + "ConnectionExtension": ConnectionExtension, + "Edge": Edge, + "ListConnection": ListConnection, + "NodeType": NodeType, + "PageInfo": PageInfo, + "connection": connection, + "from_base64": from_base64, + "to_base64": to_base64, +} + + +def __getattr__(name: str) -> Any: + if name in _DEPRECATIONS: + warnings.warn( + f"{name} should be imported from strawberry.pagination", + DeprecationWarning, + stacklevel=2, + ) + return _DEPRECATIONS[name] + + raise AttributeError(f"module {__name__} has no attribute {name}") + __all__ = [ - "Connection", - "ConnectionExtension", - "Edge", "GlobalID", "GlobalIDValueError", - "ListConnection", "Node", "NodeExtension", "NodeID", - "NodeType", - "PageInfo", - "connection", - "from_base64", "node", - "to_base64", ] diff --git a/strawberry/relay/exceptions.py b/strawberry/relay/exceptions.py index 4150633d8c..b8e3d92122 100644 --- a/strawberry/relay/exceptions.py +++ b/strawberry/relay/exceptions.py @@ -1,15 +1,13 @@ from __future__ import annotations -from collections.abc import Callable from functools import cached_property -from typing import TYPE_CHECKING, Optional, Type, cast +from typing import TYPE_CHECKING, Optional, Type from strawberry.exceptions.exception import StrawberryException from strawberry.exceptions.utils.source_finder import SourceFinder if TYPE_CHECKING: from strawberry.exceptions.exception_source import ExceptionSource - from strawberry.types.fields.resolver import StrawberryResolver class NodeIDAnnotationError(StrawberryException): @@ -40,72 +38,4 @@ def exception_source(self) -> Optional[ExceptionSource]: return source_finder.find_class_from_object(self.cls) -class RelayWrongAnnotationError(StrawberryException): - def __init__(self, field_name: str, cls: Type) -> None: - self.cls = cls - self.field_name = field_name - - self.message = ( - f'Wrong annotation used on field "{field_name}". It should be ' - 'annotated with a "Connection" subclass.' - ) - self.rich_message = ( - f"Wrong annotation for field `[underline]{self.field_name}[/]`" - ) - self.suggestion = ( - "To fix this error you can add a valid annotation, " - f"like [italic]`{self.field_name}: relay.Connection[{cls}]` " - f"or [italic]`@relay.connection(relay.Connection[{cls}])`" - ) - self.annotation_message = "relay wrong annotation" - - super().__init__(self.message) - - @cached_property - def exception_source(self) -> Optional[ExceptionSource]: - if self.cls is None: - return None # pragma: no cover - - source_finder = SourceFinder() - return source_finder.find_class_attribute_from_object(self.cls, self.field_name) - - -class RelayWrongResolverAnnotationError(StrawberryException): - def __init__(self, field_name: str, resolver: StrawberryResolver) -> None: - self.function = resolver.wrapped_func - self.field_name = field_name - - self.message = ( - f'Wrong annotation used on "{field_name}" resolver. ' - "It should be return an iterable or async iterable object." - ) - self.rich_message = ( - f"Wrong annotation used on `{field_name}` resolver. " - "It should be return an `iterable` or `async iterable` object." - ) - self.suggestion = ( - "To fix this error you can annootate your resolver to return " - "one of the following options: `List[]`, " - "`Iterator[]`, `Iterable[]`, " - "`AsyncIterator[]`, `AsyncIterable[]`, " - "`Generator[, Any, Any]` and " - "`AsyncGenerator[, Any]`." - ) - self.annotation_message = "relay wrong resolver annotation" - - super().__init__(self.message) - - @cached_property - def exception_source(self) -> Optional[ExceptionSource]: - if self.function is None: - return None # pragma: no cover - - source_finder = SourceFinder() - return source_finder.find_function_from_object(cast(Callable, self.function)) - - -__all__ = [ - "NodeIDAnnotationError", - "RelayWrongAnnotationError", - "RelayWrongResolverAnnotationError", -] +__all__ = ["NodeIDAnnotationError"] diff --git a/strawberry/relay/fields.py b/strawberry/relay/fields.py index 5af00700f8..13fe48cdb1 100644 --- a/strawberry/relay/fields.py +++ b/strawberry/relay/fields.py @@ -1,57 +1,39 @@ from __future__ import annotations import asyncio -import dataclasses import inspect +import warnings from collections import defaultdict -from collections.abc import AsyncIterable from typing import ( TYPE_CHECKING, Any, - AsyncIterator, Awaitable, Callable, DefaultDict, Dict, - ForwardRef, - Iterable, Iterator, List, - Mapping, - Optional, - Sequence, Tuple, Type, Union, cast, - overload, ) -from typing_extensions import Annotated, get_origin +from typing_extensions import Annotated -from strawberry.annotation import StrawberryAnnotation from strawberry.extensions.field_extension import ( - AsyncExtensionResolver, FieldExtension, SyncExtensionResolver, ) -from strawberry.relay.exceptions import ( - RelayWrongAnnotationError, - RelayWrongResolverAnnotationError, -) -from strawberry.types.arguments import StrawberryArgument, argument +from strawberry.pagination.fields import ConnectionExtension, connection +from strawberry.types.arguments import argument # noqa: TCH001 from strawberry.types.base import StrawberryList, StrawberryOptional -from strawberry.types.field import _RESOLVER_TYPE, StrawberryField, field +from strawberry.types.field import StrawberryField, field from strawberry.types.fields.resolver import StrawberryResolver -from strawberry.types.lazy_type import LazyType from strawberry.utils.aio import asyncgen_to_list -from strawberry.utils.typing import eval_type, is_generic_alias -from .types import Connection, GlobalID, Node, NodeIterableType, NodeType +from .types import GlobalID, Node if TYPE_CHECKING: - from typing_extensions import Literal - - from strawberry.permission import BasePermission from strawberry.types.info import Info @@ -182,142 +164,6 @@ async def resolve(resolved: Any = resolved_nodes) -> List[Node]: return resolver -class ConnectionExtension(FieldExtension): - connection_type: Type[Connection[Node]] - - def apply(self, field: StrawberryField) -> None: - field.arguments = [ - *field.arguments, - StrawberryArgument( - python_name="before", - graphql_name=None, - type_annotation=StrawberryAnnotation(Optional[str]), - description=( - "Returns the items in the list that come before the " - "specified cursor." - ), - default=None, - ), - StrawberryArgument( - python_name="after", - graphql_name=None, - type_annotation=StrawberryAnnotation(Optional[str]), - description=( - "Returns the items in the list that come after the " - "specified cursor." - ), - default=None, - ), - StrawberryArgument( - python_name="first", - graphql_name=None, - type_annotation=StrawberryAnnotation(Optional[int]), - description="Returns the first n items from the list.", - default=None, - ), - StrawberryArgument( - python_name="last", - graphql_name=None, - type_annotation=StrawberryAnnotation(Optional[int]), - description=( - "Returns the items in the list that come after the " - "specified cursor." - ), - default=None, - ), - ] - - f_type = field.type - - if isinstance(f_type, LazyType): - f_type = f_type.resolve_type() - field.type = f_type - - type_origin = get_origin(f_type) if is_generic_alias(f_type) else f_type - if not isinstance(type_origin, type) or not issubclass(type_origin, Connection): - raise RelayWrongAnnotationError(field.name, cast(type, field.origin)) - - assert field.base_resolver - # TODO: We are not using resolver_type.type because it will call - # StrawberryAnnotation.resolve, which will strip async types from the - # type (i.e. AsyncGenerator[Fruit] will become Fruit). This is done there - # for subscription support, but we can't use it here. Maybe we can refactor - # this in the future. - resolver_type = field.base_resolver.signature.return_annotation - if isinstance(resolver_type, str): - resolver_type = ForwardRef(resolver_type) - if isinstance(resolver_type, ForwardRef): - resolver_type = eval_type( - resolver_type, - field.base_resolver._namespace, - None, - ) - - origin = get_origin(resolver_type) - if origin is None or not issubclass( - origin, (Iterator, Iterable, AsyncIterator, AsyncIterable) - ): - raise RelayWrongResolverAnnotationError(field.name, field.base_resolver) - - self.connection_type = cast(Type[Connection[Node]], field.type) - - def resolve( - self, - next_: SyncExtensionResolver, - source: Any, - info: Info, - *, - before: Optional[str] = None, - after: Optional[str] = None, - first: Optional[int] = None, - last: Optional[int] = None, - **kwargs: Any, - ) -> Any: - assert self.connection_type is not None - return self.connection_type.resolve_connection( - cast(Iterable[Node], next_(source, info, **kwargs)), - info=info, - before=before, - after=after, - first=first, - last=last, - ) - - async def resolve_async( - self, - next_: AsyncExtensionResolver, - source: Any, - info: Info, - *, - before: Optional[str] = None, - after: Optional[str] = None, - first: Optional[int] = None, - last: Optional[int] = None, - **kwargs: Any, - ) -> Any: - assert self.connection_type is not None - nodes = next_(source, info, **kwargs) - # nodes might be an AsyncIterable/AsyncIterator - # In this case we don't await for it - if inspect.isawaitable(nodes): - nodes = await nodes - - resolved = self.connection_type.resolve_connection( - cast(Iterable[Node], nodes), - info=info, - before=before, - after=after, - first=first, - last=last, - ) - - # If nodes was an AsyncIterable/AsyncIterator, resolve_connection - # will return a coroutine which we need to await - if inspect.isawaitable(resolved): - resolved = await resolved - return resolved - - if TYPE_CHECKING: node = field else: @@ -327,155 +173,22 @@ def node(*args: Any, **kwargs: Any) -> StrawberryField: return field(*args, **kwargs) -@overload -def connection( - graphql_type: Optional[Type[Connection[NodeType]]] = None, - *, - resolver: Optional[_RESOLVER_TYPE[NodeIterableType[Any]]] = None, - name: Optional[str] = None, - is_subscription: bool = False, - description: Optional[str] = None, - init: Literal[True] = True, - permission_classes: Optional[List[Type[BasePermission]]] = None, - deprecation_reason: Optional[str] = None, - default: Any = dataclasses.MISSING, - default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - directives: Optional[Sequence[object]] = (), - extensions: List[FieldExtension] = (), # type: ignore -) -> Any: ... - - -@overload -def connection( - graphql_type: Optional[Type[Connection[NodeType]]] = None, - *, - name: Optional[str] = None, - is_subscription: bool = False, - description: Optional[str] = None, - permission_classes: Optional[List[Type[BasePermission]]] = None, - deprecation_reason: Optional[str] = None, - default: Any = dataclasses.MISSING, - default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - directives: Optional[Sequence[object]] = (), - extensions: List[FieldExtension] = (), # type: ignore -) -> StrawberryField: ... - - -def connection( - graphql_type: Optional[Type[Connection[NodeType]]] = None, - *, - resolver: Optional[_RESOLVER_TYPE[Any]] = None, - name: Optional[str] = None, - is_subscription: bool = False, - description: Optional[str] = None, - permission_classes: Optional[List[Type[BasePermission]]] = None, - deprecation_reason: Optional[str] = None, - default: Any = dataclasses.MISSING, - default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - directives: Optional[Sequence[object]] = (), - extensions: List[FieldExtension] = (), # type: ignore - # This init parameter is used by pyright to determine whether this field - # is added in the constructor or not. It is not used to change - # any behavior at the moment. - init: Literal[True, False, None] = None, -) -> Any: - """Annotate a property or a method to create a relay connection field. - - Relay connections are mostly used for pagination purposes. This decorator - helps creating a complete relay endpoint that provides default arguments - and has a default implementation for the connection slicing. - - Note that when setting a resolver to this field, it is expected for this - resolver to return an iterable of the expected node type, not the connection - itself. That iterable will then be paginated accordingly. So, the main use - case for this is to provide a filtered iterable of nodes by using some custom - filter arguments. - - Args: - graphql_type: The type of the nodes in the connection. This is used to - determine the type of the edges and the node field in the connection. - resolver: The resolver for the connection. This is expected to return an - iterable of the expected node type. - name: The GraphQL name of the field. - is_subscription: Whether the field is a subscription. - description: The GraphQL description of the field. - permission_classes: The permission classes to apply to the field. - deprecation_reason: The deprecation reason of the field. - default: The default value of the field. - default_factory: The default factory of the field. - metadata: The metadata of the field. - directives: The directives to apply to the field. - extensions: The extensions to apply to the field. - init: Used only for type checking purposes. - - Examples: - Annotating something like this: - - ```python - @strawberry.type - class X: - some_node: relay.Connection[SomeType] = relay.connection( - resolver=get_some_nodes, - description="ABC", +_DEPRECATIONS = { + "ConnectionField": ConnectionExtension, + "connection": connection, +} + + +def __getattr__(name: str) -> Any: + if name in _DEPRECATIONS: + warnings.warn( + f"{name} should be imported from strawberry.pagination.fields", + DeprecationWarning, + stacklevel=2, ) + return _DEPRECATIONS[name] - @relay.connection(relay.Connection[SomeType], description="ABC") - def get_some_nodes(self, age: int) -> Iterable[SomeType]: ... - ``` - - Will produce a query like this: - - ```graphql - query { - someNode ( - before: String - after: String - first: String - after: String - age: Int - ) { - totalCount - pageInfo { - hasNextPage - hasPreviousPage - startCursor - endCursor - } - edges { - cursor - node { - id - ... - } - } - } - } - ``` - - .. _Relay connections: - https://relay.dev/graphql/connections.htm - - """ - f = StrawberryField( - python_name=None, - graphql_name=name, - description=description, - type_annotation=StrawberryAnnotation.from_annotation(graphql_type), - is_subscription=is_subscription, - permission_classes=permission_classes or [], - deprecation_reason=deprecation_reason, - default=default, - default_factory=default_factory, - metadata=metadata, - directives=directives or (), - extensions=[*extensions, ConnectionExtension()], - ) - if resolver is not None: - f = f(resolver) - return f - - -__all__ = ["node", "connection"] + raise AttributeError(f"module {__name__} has no attribute {name}") + + +__all__ = ["NodeExtension", "node"] diff --git a/strawberry/relay/types.py b/strawberry/relay/types.py index 869e017be9..29e75fb5b2 100644 --- a/strawberry/relay/types.py +++ b/strawberry/relay/types.py @@ -2,22 +2,16 @@ import dataclasses import inspect -import itertools import sys +import warnings from typing import ( TYPE_CHECKING, Any, - AsyncIterable, - AsyncIterator, Awaitable, ClassVar, ForwardRef, - Generic, Iterable, - Iterator, - List, Optional, - Sequence, Type, TypeVar, Union, @@ -26,44 +20,37 @@ ) from typing_extensions import Annotated, Literal, Self, TypeAlias, get_args, get_origin +from strawberry.pagination.types import ( + PREFIX, + Connection, + Edge, + ListConnection, + NodeIterableType, + NodeType, + PageInfo, +) +from strawberry.pagination.utils import ( + from_base64, + to_base64, +) from strawberry.relay.exceptions import NodeIDAnnotationError from strawberry.types.base import ( - StrawberryContainer, StrawberryObjectDefinition, - get_object_definition, ) from strawberry.types.field import field from strawberry.types.info import Info # noqa: TCH001 from strawberry.types.lazy_type import LazyType -from strawberry.types.object_type import interface, type +from strawberry.types.object_type import interface from strawberry.types.private import StrawberryPrivate -from strawberry.utils.aio import aenumerate, aislice, resolve_awaitable -from strawberry.utils.inspect import in_async_context +from strawberry.utils.aio import resolve_awaitable from strawberry.utils.typing import eval_type, is_classvar -from .utils import ( - SliceMetadata, - from_base64, - should_resolve_list_connection_edges, - to_base64, -) - if TYPE_CHECKING: from strawberry.scalars import ID from strawberry.utils.await_maybe import AwaitableOrValue _T = TypeVar("_T") -NodeIterableType: TypeAlias = Union[ - Iterator[_T], - Iterable[_T], - AsyncIterator[_T], - AsyncIterable[_T], -] -NodeType = TypeVar("NodeType", bound="Node") - -PREFIX = "arrayconnection" - class GlobalIDValueError(ValueError): """GlobalID value error, usually related to parsing or serialization.""" @@ -618,320 +605,27 @@ def resolve_node( return next(iter(cast(Iterable[Self], retval))) -@type(description="Information to aid in pagination.") -class PageInfo: - """Information to aid in pagination. - - Attributes: - has_next_page: - When paginating forwards, are there more items? - has_previous_page: - When paginating backwards, are there more items? - start_cursor: - When paginating backwards, the cursor to continue - end_cursor: - When paginating forwards, the cursor to continue - """ - - has_next_page: bool = field( - description="When paginating forwards, are there more items?", - ) - has_previous_page: bool = field( - description="When paginating backwards, are there more items?", - ) - start_cursor: Optional[str] = field( - description="When paginating backwards, the cursor to continue.", - ) - end_cursor: Optional[str] = field( - description="When paginating forwards, the cursor to continue.", - ) - - -@type(description="An edge in a connection.") -class Edge(Generic[NodeType]): - """An edge in a connection. - - Attributes: - cursor: - A cursor for use in pagination - node: - The item at the end of the edge - """ - - cursor: str = field(description="A cursor for use in pagination") - node: NodeType = field(description="The item at the end of the edge") - - @classmethod - def resolve_edge(cls, node: NodeType, *, cursor: Any = None) -> Self: - return cls(cursor=to_base64(PREFIX, cursor), node=node) - - -@type(description="A connection to a list of items.") -class Connection(Generic[NodeType]): - """A connection to a list of items. - - Attributes: - page_info: - Pagination data for this connection - edges: - Contains the nodes in this connection - - """ - - page_info: PageInfo = field(description="Pagination data for this connection") - edges: List[Edge[NodeType]] = field( - description="Contains the nodes in this connection" - ) - - @classmethod - def resolve_node(cls, node: Any, *, info: Info, **kwargs: Any) -> NodeType: - """The identity function for the node. - - This method is used to resolve a node of a different type to the - connection's `NodeType`. - - By default it returns the node itself, but subclasses can override - this to provide a custom implementation. - - Args: - node: - The resolved node which should return an instance of this - connection's `NodeType`. - info: - The strawberry execution info resolve the type name from. - **kwargs: - Additional arguments passed to the resolver. - - """ - return node - - @classmethod - def resolve_connection( - cls, - nodes: NodeIterableType[NodeType], - *, - info: Info, - before: Optional[str] = None, - after: Optional[str] = None, - first: Optional[int] = None, - last: Optional[int] = None, - **kwargs: Any, - ) -> AwaitableOrValue[Self]: - """Resolve a connection from nodes. - - Subclasses must define this method to paginate nodes based - on `first`/`last`/`before`/`after` arguments. - - Args: - info: The strawberry execution info resolve the type name from. - nodes: An iterable/iteretor of nodes to paginate. - before: Returns the items in the list that come before the specified cursor. - after: Returns the items in the list that come after the specified cursor. - first: Returns the first n items from the list. - last: Returns the items in the list that come after the specified cursor. - kwargs: Additional arguments passed to the resolver. - - Returns: - The resolved `Connection` - - """ - raise NotImplementedError - - -@type(name="Connection", description="A connection to a list of items.") -class ListConnection(Connection[NodeType]): - """A connection to a list of items. - - Attributes: - page_info: - Pagination data for this connection - edges: - Contains the nodes in this connection - - """ - - page_info: PageInfo = field(description="Pagination data for this connection") - edges: List[Edge[NodeType]] = field( - description="Contains the nodes in this connection" - ) - - @classmethod - def resolve_connection( - cls, - nodes: NodeIterableType[NodeType], - *, - info: Info, - before: Optional[str] = None, - after: Optional[str] = None, - first: Optional[int] = None, - last: Optional[int] = None, - **kwargs: Any, - ) -> AwaitableOrValue[Self]: - """Resolve a connection from the list of nodes. - - This uses the described Relay Pagination algorithm_ - - Args: - info: The strawberry execution info resolve the type name from. - nodes: An iterable/iteretor of nodes to paginate. - before: Returns the items in the list that come before the specified cursor. - after: Returns the items in the list that come after the specified cursor. - first: Returns the first n items from the list. - last: Returns the items in the list that come after the specified cursor. - kwargs: Additional arguments passed to the resolver. +_DEPRECATIONS = { + "Connection": Connection, + "Edge": Edge, + "ListConnection": ListConnection, + "NodeIterableType": NodeIterableType, + "NodeType": NodeType, + "PREFIX": PREFIX, + "PageInfo": PageInfo, +} - Returns: - The resolved `Connection` - .. _Relay Pagination algorithm: - https://relay.dev/graphql/connections.htm#sec-Pagination-algorithm - """ - slice_metadata = SliceMetadata.from_arguments( - info, - before=before, - after=after, - first=first, - last=last, +def __getattr__(name: str) -> Any: + if name in _DEPRECATIONS: + warnings.warn( + f"{name} should be imported from strawberry.pagination.types", + DeprecationWarning, + stacklevel=2, ) + return _DEPRECATIONS[name] - type_def = get_object_definition(cls) - assert type_def - field_def = type_def.get_field("edges") - assert field_def - - field = field_def.resolve_type(type_definition=type_def) - while isinstance(field, StrawberryContainer): - field = field.of_type - - edge_class = cast(Edge[NodeType], field) - - if isinstance(nodes, (AsyncIterator, AsyncIterable)) and in_async_context(): - - async def resolver() -> Self: - try: - iterator = cast( - Union[AsyncIterator[NodeType], AsyncIterable[NodeType]], - cast(Sequence, nodes)[ - slice_metadata.start : slice_metadata.overfetch - ], - ) - except TypeError: - # TODO: Why mypy isn't narrowing this based on the if above? - assert isinstance(nodes, (AsyncIterator, AsyncIterable)) - iterator = aislice( - nodes, - slice_metadata.start, - slice_metadata.overfetch, - ) - - # The slice above might return an object that now is not async - # iterable anymore (e.g. an already cached django queryset) - if isinstance(iterator, (AsyncIterator, AsyncIterable)): - edges: List[Edge] = [ - edge_class.resolve_edge( - cls.resolve_node(v, info=info, **kwargs), - cursor=slice_metadata.start + i, - ) - async for i, v in aenumerate(iterator) - ] - else: - edges: List[Edge] = [ # type: ignore[no-redef] - edge_class.resolve_edge( - cls.resolve_node(v, info=info, **kwargs), - cursor=slice_metadata.start + i, - ) - for i, v in enumerate(iterator) - ] - - has_previous_page = slice_metadata.start > 0 - if ( - slice_metadata.expected is not None - and len(edges) == slice_metadata.expected + 1 - ): - # Remove the overfetched result - edges = edges[:-1] - has_next_page = True - elif slice_metadata.end == sys.maxsize: - # Last was asked without any after/before - assert last is not None - original_len = len(edges) - edges = edges[-last:] - has_next_page = False - has_previous_page = len(edges) != original_len - else: - has_next_page = False - - return cls( - edges=edges, - page_info=PageInfo( - start_cursor=edges[0].cursor if edges else None, - end_cursor=edges[-1].cursor if edges else None, - has_previous_page=has_previous_page, - has_next_page=has_next_page, - ), - ) - - return resolver() - - try: - iterator = cast( - Union[Iterator[NodeType], Iterable[NodeType]], - cast(Sequence, nodes)[slice_metadata.start : slice_metadata.overfetch], - ) - except TypeError: - assert isinstance(nodes, (Iterable, Iterator)) - iterator = itertools.islice( - nodes, - slice_metadata.start, - slice_metadata.overfetch, - ) - - if not should_resolve_list_connection_edges(info): - return cls( - edges=[], - page_info=PageInfo( - start_cursor=None, - end_cursor=None, - has_previous_page=False, - has_next_page=False, - ), - ) - - edges = [ - edge_class.resolve_edge( - cls.resolve_node(v, info=info, **kwargs), - cursor=slice_metadata.start + i, - ) - for i, v in enumerate(iterator) - ] - - has_previous_page = slice_metadata.start > 0 - if ( - slice_metadata.expected is not None - and len(edges) == slice_metadata.expected + 1 - ): - # Remove the overfetched result - edges = edges[:-1] - has_next_page = True - elif slice_metadata.end == sys.maxsize: - # Last was asked without any after/before - assert last is not None - original_len = len(edges) - edges = edges[-last:] - has_next_page = False - has_previous_page = len(edges) != original_len - else: - has_next_page = False - - return cls( - edges=edges, - page_info=PageInfo( - start_cursor=edges[0].cursor if edges else None, - end_cursor=edges[-1].cursor if edges else None, - has_previous_page=has_previous_page, - has_next_page=has_next_page, - ), - ) + raise AttributeError(f"module {__name__} has no attribute {name}") __all__ = [ @@ -941,11 +635,4 @@ async def resolver() -> Self: "NodeID", "NodeIDAnnotationError", "NodeIDPrivate", - "NodeIterableType", - "NodeType", - "PREFIX", - "Connection", - "Edge", - "PageInfo", - "ListConnection", ] diff --git a/strawberry/relay/utils.py b/strawberry/relay/utils.py index c25bd537b0..ca3739b661 100644 --- a/strawberry/relay/utils.py +++ b/strawberry/relay/utils.py @@ -1,198 +1,30 @@ -from __future__ import annotations - -import base64 -import dataclasses -import sys -from typing import TYPE_CHECKING, Any, Tuple, Union -from typing_extensions import Self, assert_never - -from strawberry.types.base import StrawberryObjectDefinition -from strawberry.types.nodes import InlineFragment, Selection - -if TYPE_CHECKING: - from strawberry.types.info import Info - - -def from_base64(value: str) -> Tuple[str, str]: - """Parse the base64 encoded relay value. - - Args: - value: - The value to be parsed - - Returns: - A tuple of (TypeName, NodeID). - - Raises: - ValueError: - If the value is not in the expected format - - """ - try: - res = base64.b64decode(value.encode()).decode().split(":", 1) - except Exception as e: - raise ValueError(str(e)) from e - - if len(res) != 2: - raise ValueError(f"{res} expected to contain only 2 items") - - return res[0], res[1] - - -def to_base64(type_: Union[str, type, StrawberryObjectDefinition], node_id: Any) -> str: - """Encode the type name and node id to a base64 string. - - Args: - type_: - The GraphQL type, type definition or type name. - node_id: - The node id itself - - Returns: - A GlobalID, which is a string resulting from base64 encoding :. - - Raises: - ValueError: - If the value is not a valid GraphQL type or name - - """ - try: - if isinstance(type_, str): - type_name = type_ - elif isinstance(type_, StrawberryObjectDefinition): - type_name = type_.name - elif isinstance(type_, type): - type_name = type_.__strawberry_definition__.name # type:ignore - else: # pragma: no cover - assert_never(type_) - except Exception as e: - raise ValueError(f"{type_} is not a valid GraphQL type or name") from e - - return base64.b64encode(f"{type_name}:{node_id}".encode()).decode() - - -def should_resolve_list_connection_edges(info: Info) -> bool: - """Check if the user requested to resolve the `edges` field of a connection. - - Args: - info: - The strawberry execution info resolve the type name from - - Returns: - True if the user requested to resolve the `edges` field of a connection, False otherwise. - - """ - resolve_for_field_names = {"edges", "pageInfo"} - - def _check_selection(selection: Selection) -> bool: - """Recursively inspect the selection to check if the user requested to resolve the `edges` field. - - Args: - selection (Selection): The selection to check. - - Returns: - bool: True if the user requested to resolve the `edges` field of a connection, False otherwise. - """ - if ( - not isinstance(selection, InlineFragment) - and selection.name in resolve_for_field_names - ): - return True - if selection.selections: - return any( - _check_selection(selection) for selection in selection.selections - ) - return False - - for selection_field in info.selected_fields: - for selection in selection_field.selections: - if _check_selection(selection): - return True - return False - - -@dataclasses.dataclass -class SliceMetadata: - start: int - end: int - expected: int | None - - @property - def overfetch(self) -> int: - # Overfetch by 1 to check if we have a next result - return self.end + 1 if self.end != sys.maxsize else self.end - - @classmethod - def from_arguments( - cls, - info: Info, - *, - before: str | None = None, - after: str | None = None, - first: int | None = None, - last: int | None = None, - ) -> Self: - """Get the slice metadata to use on ListConnection.""" - from strawberry.relay.types import PREFIX - - max_results = info.schema.config.relay_max_results - start = 0 - end: int | None = None - - if after: - after_type, after_parsed = from_base64(after) - if after_type != PREFIX: - raise TypeError("Argument 'after' contains a non-existing value.") - - start = int(after_parsed) + 1 - if before: - before_type, before_parsed = from_base64(before) - if before_type != PREFIX: - raise TypeError("Argument 'before' contains a non-existing value.") - end = int(before_parsed) - - if isinstance(first, int): - if first < 0: - raise ValueError("Argument 'first' must be a non-negative integer.") - - if first > max_results: - raise ValueError( - f"Argument 'first' cannot be higher than {max_results}." - ) - - if end is not None: - start = max(0, end - 1) - - end = start + first - if isinstance(last, int): - if last < 0: - raise ValueError("Argument 'last' must be a non-negative integer.") - - if last > max_results: - raise ValueError( - f"Argument 'last' cannot be higher than {max_results}." - ) - - if end is not None: - start = max(start, end - last) - else: - end = sys.maxsize - - if end is None: - end = start + max_results - - expected = end - start if end != sys.maxsize else None - - return cls( - start=start, - end=end, - expected=expected, +import warnings +from typing import Any + +from strawberry.pagination.utils import ( + SliceMetadata, + from_base64, + should_resolve_list_connection_edges, + to_base64, +) + +__all__ = [] + +_DEPRECATIONS = { + "SliceMetadata": SliceMetadata, + "from_base64": from_base64, + "should_resolve_list_connection_edges": should_resolve_list_connection_edges, + "to_base64": to_base64, +} + + +def __getattr__(name: str) -> Any: + if name in _DEPRECATIONS: + warnings.warn( + f"{name} should be imported from strawberry.pagination.utils", + DeprecationWarning, + stacklevel=2, ) + return _DEPRECATIONS[name] - -__all__ = [ - "from_base64", - "to_base64", - "should_resolve_list_connection_edges", - "SliceMetadata", -] + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/strawberry/utils.py b/strawberry/utils.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/pagination/__init__.py b/tests/pagination/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/pagination/test_exceptions.py b/tests/pagination/test_exceptions.py new file mode 100644 index 0000000000..bdf735dfef --- /dev/null +++ b/tests/pagination/test_exceptions.py @@ -0,0 +1,70 @@ +from typing import List + +import pytest + +import strawberry +from strawberry.pagination.exceptions import ( + ConnectionWrongAnnotationError, + ConnectionWrongResolverAnnotationError, +) +from strawberry.pagination.fields import connection +from strawberry.pagination.types import Connection + + +@pytest.mark.raises_strawberry_exception( + ConnectionWrongAnnotationError, + match=( + 'Wrong annotation used on field "fruits_conn". ' + 'It should be annotated with a "Connection" subclass.' + ), +) +def test_raises_error_on_connection_missing_annotation(): + @strawberry.type + class Fruit: + pk: str + + @strawberry.type + class Query: + fruits_conn: List[Fruit] = connection() + + strawberry.Schema(query=Query) + + +@pytest.mark.raises_strawberry_exception( + ConnectionWrongAnnotationError, + match=( + 'Wrong annotation used on field "custom_resolver". ' + 'It should be annotated with a "Connection" subclass.' + ), +) +def test_raises_error_on_connection_wrong_annotation(): + @strawberry.type + class Fruit: + pk: str + + @strawberry.type + class Query: + @connection(List[Fruit]) # type: ignore + def custom_resolver(self) -> List[Fruit]: ... + + strawberry.Schema(query=Query) + + +@pytest.mark.raises_strawberry_exception( + ConnectionWrongResolverAnnotationError, + match=( + 'Wrong annotation used on "custom_resolver" resolver. ' + "It should be return an iterable or async iterable object." + ), +) +def test_raises_error_on_connection_resolver_wrong_annotation(): + @strawberry.type + class Fruit: + pk: str + + @strawberry.type + class Query: + @connection(Connection[Fruit]) # type: ignore + def custom_resolver(self): ... + + strawberry.Schema(query=Query) diff --git a/tests/pagination/test_types.py b/tests/pagination/test_types.py new file mode 100644 index 0000000000..dafa73e36f --- /dev/null +++ b/tests/pagination/test_types.py @@ -0,0 +1,146 @@ +from collections.abc import Iterable +from typing import AsyncGenerator, AsyncIterable + +from pytest_mock import MockerFixture + +import strawberry +from strawberry.pagination.fields import connection +from strawberry.pagination.types import ListConnection, NodeType + + +async def test_resolve_async_list_connection(): + @strawberry.type + class SomeType: + id: int + + @strawberry.type + class Query: + @connection(ListConnection[SomeType]) + async def some_type_conn(self) -> AsyncGenerator[SomeType, None]: + yield SomeType(id=0) + yield SomeType(id=1) + yield SomeType(id=2) + + schema = strawberry.Schema(query=Query) + ret = await schema.execute( + """\ + query { + someTypeConn { + edges { + node { + id + } + } + } + } + """ + ) + assert ret.errors is None + assert ret.data == { + "someTypeConn": { + "edges": [ + {"node": {"id": 0}}, + {"node": {"id": 1}}, + {"node": {"id": 2}}, + ], + } + } + + +async def test_resolve_async_list_connection_but_sync_after_sliced(): + # We are mimicking an object which is async iterable, but when sliced + # returns something that is not anymore. This is similar to an already + # prefetched django QuerySet, which is async iterable by default, but + # when sliced, since it is already prefetched, will return a list. + class Slicer: + def __init__(self, nodes) -> None: + self.nodes = nodes + + async def __aiter__(self): + for n in self.nodes: + yield n + + def __getitem__(self, key): + return self.nodes[key] + + @strawberry.type + class SomeType: + id: int + + @strawberry.type + class Query: + @connection(ListConnection[SomeType]) + async def some_type_conn(self) -> AsyncIterable[SomeType]: + return Slicer([SomeType(id=0), SomeType(id=1), SomeType(id=2)]) + + schema = strawberry.Schema(query=Query) + ret = await schema.execute( + """\ + query { + someTypeConn { + edges { + node { + id + } + } + } + } + """ + ) + assert ret.errors is None + assert ret.data == { + "someTypeConn": { + "edges": [ + {"node": {"id": 0}}, + {"node": {"id": 1}}, + {"node": {"id": 2}}, + ], + } + } + + +def test_list_connection_without_edges_or_page_info_should_not_call_resolver( + mocker: MockerFixture, +): + resolve_edge_mock = mocker.patch("strawberry.pagination.types.Edge.resolve_edge") + + @strawberry.type(name="Connection", description="A connection to a list of items.") + class DummyListConnectionWithTotalCount(ListConnection[NodeType]): + @strawberry.field(description="Total quantity of existing nodes.") + def total_count(self) -> int: + return -1 + + @strawberry.type + class Fruit: + id: int + + def fruits_resolver() -> Iterable[Fruit]: + return [Fruit(id=1), Fruit(id=2), Fruit(id=3), Fruit(id=4), Fruit(id=5)] + + fruits_resolver_spy = mocker.spy(fruits_resolver, "__call__") + + @strawberry.type + class Query: + fruits: DummyListConnectionWithTotalCount[Fruit] = connection( + resolver=fruits_resolver + ) + + schema = strawberry.Schema(query=Query) + ret = schema.execute_sync( + """ + query { + fruits { + totalCount + } + } + """ + ) + assert ret.errors is None + assert ret.data == { + "fruits": { + "totalCount": -1, + } + } + + resolve_edge_mock.assert_not_called() + fruits_resolver_spy.assert_not_called() diff --git a/tests/relay/test_utils.py b/tests/pagination/test_utils.py similarity index 94% rename from tests/relay/test_utils.py rename to tests/pagination/test_utils.py index e14a375522..13eee7402e 100644 --- a/tests/relay/test_utils.py +++ b/tests/pagination/test_utils.py @@ -6,8 +6,9 @@ import pytest -from strawberry.relay.types import PREFIX -from strawberry.relay.utils import ( +import strawberry +from strawberry.pagination.types import PREFIX +from strawberry.pagination.utils import ( SliceMetadata, from_base64, to_base64, @@ -15,8 +16,6 @@ from strawberry.schema.config import StrawberryConfig from strawberry.types.base import get_object_definition -from .schema import Fruit - def test_from_base64(): type_name, node_id = from_base64("Zm9vYmFyOjE=") # foobar:1 @@ -55,11 +54,19 @@ def test_to_base64(): def test_to_base64_with_type(): + @strawberry.type + class Fruit: + id: int + value = to_base64(Fruit, "1") assert value == "RnJ1aXQ6MQ==" def test_to_base64_with_typedef(): + @strawberry.type + class Fruit: + id: int + value = to_base64( get_object_definition(Fruit, strict=True), "1", diff --git a/tests/relay/test_exceptions.py b/tests/relay/test_exceptions.py index 174c3748f7..ad0dfe0ae2 100644 --- a/tests/relay/test_exceptions.py +++ b/tests/relay/test_exceptions.py @@ -5,11 +5,7 @@ import strawberry from strawberry import Info, relay from strawberry.relay import GlobalID -from strawberry.relay.exceptions import ( - NodeIDAnnotationError, - RelayWrongAnnotationError, - RelayWrongResolverAnnotationError, -) +from strawberry.relay.exceptions import NodeIDAnnotationError @strawberry.type @@ -96,62 +92,3 @@ class Query: def fruits(self) -> List[Fruit]: ... strawberry.Schema(query=Query) - - -@pytest.mark.raises_strawberry_exception( - RelayWrongAnnotationError, - match=( - 'Wrong annotation used on field "fruits_conn". ' - 'It should be annotated with a "Connection" subclass.' - ), -) -def test_raises_error_on_connection_missing_annotation(): - @strawberry.type - class Fruit(relay.Node): - pk: relay.NodeID[str] - - @strawberry.type - class Query: - fruits_conn: List[Fruit] = relay.connection() - - strawberry.Schema(query=Query) - - -@pytest.mark.raises_strawberry_exception( - RelayWrongAnnotationError, - match=( - 'Wrong annotation used on field "custom_resolver". ' - 'It should be annotated with a "Connection" subclass.' - ), -) -def test_raises_error_on_connection_wrong_annotation(): - @strawberry.type - class Fruit(relay.Node): - pk: relay.NodeID[str] - - @strawberry.type - class Query: - @relay.connection(List[Fruit]) # type: ignore - def custom_resolver(self) -> List[Fruit]: ... - - strawberry.Schema(query=Query) - - -@pytest.mark.raises_strawberry_exception( - RelayWrongResolverAnnotationError, - match=( - 'Wrong annotation used on "custom_resolver" resolver. ' - "It should be return an iterable or async iterable object." - ), -) -def test_raises_error_on_connection_resolver_wrong_annotation(): - @strawberry.type - class Fruit(relay.Node): - pk: relay.NodeID[str] - - @strawberry.type - class Query: - @relay.connection(relay.Connection[Fruit]) # type: ignore - def custom_resolver(self): ... - - strawberry.Schema(query=Query) diff --git a/tests/relay/test_types.py b/tests/relay/test_types.py index 1171d7896a..cb9b4a4408 100644 --- a/tests/relay/test_types.py +++ b/tests/relay/test_types.py @@ -1,15 +1,13 @@ -from typing import Any, AsyncGenerator, AsyncIterable, Optional, Union, cast +from typing import Any, Optional, Union, cast from typing_extensions import assert_type -from unittest.mock import MagicMock import pytest import strawberry from strawberry import relay -from strawberry.relay.utils import to_base64 from strawberry.types.info import Info -from .schema import Fruit, FruitAsync, fruits_resolver, schema +from .schema import Fruit, FruitAsync, schema class FakeInfo: @@ -149,97 +147,6 @@ class Foo: ... fruit = await gid.resolve_node(fake_info, ensure_type=Foo) -async def test_resolve_async_list_connection(): - @strawberry.type - class SomeType(relay.Node): - id: relay.NodeID[int] - - @strawberry.type - class Query: - @relay.connection(relay.ListConnection[SomeType]) - async def some_type_conn(self) -> AsyncGenerator[SomeType, None]: - yield SomeType(id=0) - yield SomeType(id=1) - yield SomeType(id=2) - - schema = strawberry.Schema(query=Query) - ret = await schema.execute( - """\ - query { - someTypeConn { - edges { - node { - id - } - } - } - } - """ - ) - assert ret.errors is None - assert ret.data == { - "someTypeConn": { - "edges": [ - {"node": {"id": to_base64("SomeType", 0)}}, - {"node": {"id": to_base64("SomeType", 1)}}, - {"node": {"id": to_base64("SomeType", 2)}}, - ], - } - } - - -async def test_resolve_async_list_connection_but_sync_after_sliced(): - # We are mimicking an object which is async iterable, but when sliced - # returns something that is not anymore. This is similar to an already - # prefetched django QuerySet, which is async iterable by default, but - # when sliced, since it is already prefetched, will return a list. - class Slicer: - def __init__(self, nodes) -> None: - self.nodes = nodes - - async def __aiter__(self): - for n in self.nodes: - yield n - - def __getitem__(self, key): - return self.nodes[key] - - @strawberry.type - class SomeType(relay.Node): - id: relay.NodeID[int] - - @strawberry.type - class Query: - @relay.connection(relay.ListConnection[SomeType]) - async def some_type_conn(self) -> AsyncIterable[SomeType]: - return Slicer([SomeType(id=0), SomeType(id=1), SomeType(id=2)]) - - schema = strawberry.Schema(query=Query) - ret = await schema.execute( - """\ - query { - someTypeConn { - edges { - node { - id - } - } - } - } - """ - ) - assert ret.errors is None - assert ret.data == { - "someTypeConn": { - "edges": [ - {"node": {"id": to_base64("SomeType", 0)}}, - {"node": {"id": to_base64("SomeType", 1)}}, - {"node": {"id": to_base64("SomeType", 2)}}, - ], - } - } - - def test_overwrite_resolve_id_and_no_node_id(): @strawberry.type class Fruit(relay.Node): @@ -258,39 +165,6 @@ def fruit(self) -> Fruit: strawberry.Schema(query=Query) -def test_list_connection_without_edges_or_page_info(mocker: MagicMock): - @strawberry.type(name="Connection", description="A connection to a list of items.") - class DummyListConnectionWithTotalCount(relay.ListConnection[relay.NodeType]): - @strawberry.field(description="Total quantity of existing nodes.") - def total_count(self) -> int: - return -1 - - @strawberry.type - class Query: - fruits: DummyListConnectionWithTotalCount[Fruit] = relay.connection( - resolver=fruits_resolver - ) - - mock = mocker.patch("strawberry.relay.types.Edge.resolve_edge") - schema = strawberry.Schema(query=Query) - ret = schema.execute_sync( - """ - query { - fruits { - totalCount - } - } - """ - ) - mock.assert_not_called() - assert ret.errors is None - assert ret.data == { - "fruits": { - "totalCount": -1, - } - } - - def test_list_connection_with_nested_fragments(): ret = schema.execute_sync( """