diff --git a/RELEASE.md b/RELEASE.md
new file mode 100644
index 0000000000..bf84d2ce69
--- /dev/null
+++ b/RELEASE.md
@@ -0,0 +1,16 @@
+Release type: minor
+
+This release moves the `Connection` implementation outside the relay package,
+allowing it to be used for general-purpose cursor pagination.
+
+The following now can be imported from `strawberry.pagination`:
+
+- `Connection` - base generic class for implementing connections
+- `ListConnection` - a limit-offset implementation of the connection
+- `connection` - field decorator for creating connections
+
+Those can still be used together with the relay package, but importing from it
+is now deprecated.
+
+You can read more about connections in the
+[Strawberry Connection Docs](https://strawberry.rocks/docs/guides/pagination/connections).
diff --git a/docs/errors/relay-wrong-annotation.md b/docs/errors/connection-wrong-annotation.md
similarity index 100%
rename from docs/errors/relay-wrong-annotation.md
rename to docs/errors/connection-wrong-annotation.md
diff --git a/docs/errors/relay-wrong-resolver-annotation.md b/docs/errors/connection-wrong-resolver-annotation.md
similarity index 69%
rename from docs/errors/relay-wrong-resolver-annotation.md
rename to docs/errors/connection-wrong-resolver-annotation.md
index 80c732dd95..1a99d767be 100644
--- a/docs/errors/relay-wrong-resolver-annotation.md
+++ b/docs/errors/connection-wrong-resolver-annotation.md
@@ -1,13 +1,13 @@
---
-title: Relay wrong resolver annotation Error
+title: Connection wrong resolver annotation Error
---
-# Relay wrong resolver annotation error
+# Connection wrong resolver annotation error
## Description
-This error is thrown when a field on a relay connection was defined with a
-resolver that returns something that is not compatible with pagination.
+This error is thrown when a field on a connection was defined with a resolver
+that returns something that is not compatible with pagination.
For example, the following code would throw this error:
@@ -15,19 +15,19 @@ For example, the following code would throw this error:
from typing import Any
import strawberry
-from strawberry import relay
+from strawberry.pagination import connection
@strawberry.type
-class MyType(relay.Node): ...
+class MyType(Node): ...
@strawberry.type
class Query:
- @relay.connection(relay.Connection[MyType])
+ @connection(Connection[MyType])
def some_connection_returning_mytype(self) -> MyType: ...
- @relay.connection(relay.Connection[MyType])
+ @connection(Connection[MyType])
def some_connection_returning_any(self) -> Any: ...
```
@@ -53,22 +53,22 @@ For example:
from typing import Any
import strawberry
-from strawberry import relay
+from strawberry.pagination import connection, Connection
@strawberry.type
-class MyType(relay.Node): ...
+class MyType: ...
@strawberry.type
class Query:
- @relay.connection(relay.Connection[MyType])
+ @connection(Connection[MyType])
def some_connection(self) -> Iterable[MyType]: ...
```
Note that if you are returning a type different than the connection type, you
will need to subclass the connection type and override its `resolve_node`
- method to convert it to the correct type, as explained in the [relay
- guide](../guides/relay).
+ method to convert it to the correct type, as explained in the [pagination
+ guide](../guides/pagination).
diff --git a/docs/guides/pagination/connections.md b/docs/guides/pagination/connections.md
index 37c6a89f3a..44aab943e1 100644
--- a/docs/guides/pagination/connections.md
+++ b/docs/guides/pagination/connections.md
@@ -1,14 +1,14 @@
---
-title: Pagination - Implementing the Relay Connection Specification
+title: Pagination - Implementing the Connection Specification
---
-# Implementing the Relay Connection Specification
+# Implementing the Connection Specification
We naively implemented cursor based pagination in the
[previous tutorial](./cursor-based.md). To ensure a consistent implementation of
this pattern, the Relay project has a formal
-[specification](https://relay.dev/graphql/connections.htm) you can follow for
-building GraphQL APIs which use a cursor based connection pattern.
+[connection specification](https://relay.dev/graphql/connections.htm), and
+strawberry provides a `Connection` generic type to implement it.
By the end of this tutorial, we should be able to return a connection of users
when requested.
@@ -74,348 +74,270 @@ query getUsers {
-## Connections
+## Connection
A Connection represents a paginated relationship between two entities. This
pattern is used when the relationship itself has attributes. For example, we
might have a connection of users to represent a paginated list of users.
-Let us define a Connection type which takes in a Generic ObjectType.
+Strawberry provides a `relay.Connection` class, which contains the basics of
+what you need to implement the specification, but don't implement any pagination
+logic. To use it, all you need to do is to subclass it and implement its
+abstract `resolve_connection` classmethod.
-```py
-# example.py
+
-from typing import Generic, TypeVar
+For basic cases, Strawberry provides a `ListConnection` which already implements
+a `resolve_connection`, which we are going to look in the next section.
-import strawberry
-
-
-GenericType = TypeVar("GenericType")
-
-
-@strawberry.type
-class Connection(Generic[GenericType]):
- page_info: "PageInfo" = strawberry.field(
- description="Information to aid in pagination."
- )
-
- edges: list["Edge[GenericType]"] = strawberry.field(
- description="A list of edges in this connection."
- )
-```
-
-Connections must have atleast two fields: `edges` and `page_info`.
-
-The `page_info` field contains metadata about the connection. Following the
-Relay specification, we can define a `PageInfo` type like this:
-
-```py line=22-38
-# example.py
+
-from typing import Generic, TypeVar
+Lets look at an example for a connection of users:
+```python
import strawberry
-
-
-GenericType = TypeVar("GenericType")
-
-
-@strawberry.type
-class Connection(Generic[GenericType]):
- page_info: "PageInfo" = strawberry.field(
- description="Information to aid in pagination."
- )
-
- edges: list["Edge[GenericType]"] = strawberry.field(
- description="A list of edges in this connection."
- )
+from strawberry.pagination import Connection, Edge, to_base64
@strawberry.type
-class PageInfo:
- has_next_page: bool = strawberry.field(
- description="When paginating forwards, are there more items?"
- )
-
- has_previous_page: bool = strawberry.field(
- description="When paginating backwards, are there more items?"
- )
-
- start_cursor: Optional[str] = strawberry.field(
- description="When paginating backwards, the cursor to continue."
- )
-
- end_cursor: Optional[str] = strawberry.field(
- description="When paginating forwards, the cursor to continue."
- )
-```
-
-You can read more about the `PageInfo` type at:
-
-- https://graphql.org/learn/pagination/#pagination-and-edges
-- https://relay.dev/graphql/connections.htm
-
-The `edges` field must return a list type that wraps an edge type.
-
-Following the Relay specification, let us define an Edge that takes in a generic
-ObjectType.
-
-```py line=41-49
-# example.py
-
-from typing import Generic, TypeVar
-
-import strawberry
+class UserConnection(Connection[User]):
+ @classmethod
+ def resolve_connection(
+ cls,
+ nodes: Iterable[Fruit],
+ *,
+ info: Optional[Info] = None,
+ before: Optional[str] = None,
+ after: Optional[str] = None,
+ first: Optional[int] = None,
+ last: Optional[int] = None,
+ ):
+ # NOTE: This is a showcase implementation and is far from
+ # being optimal performance wise
+ edges_mapping = {
+ to_base64("cursor-name", n.name): Edge(
+ node=n,
+ cursor=to_base64("cursor-name", n.name),
+ )
+ for n in sorted(nodes, key=lambda f: f.name)
+ }
+ edges = list(edges_mapping.values())
+ first_edge = edges[0] if edges else None
+ last_edge = edges[-1] if edges else None
+ if after is not None:
+ after_edge_idx = edges.index(edges_mapping[after])
+ edges = [e for e in edges if edges.index(e) > after_edge_idx]
-GenericType = TypeVar("GenericType")
+ if before is not None:
+ before_edge_idx = edges.index(edges_mapping[before])
+ edges = [e for e in edges if edges.index(e) < before_edge_idx]
+ if first is not None:
+ edges = edges[:first]
-@strawberry.type
-class Connection(Generic[GenericType]):
- page_info: "PageInfo" = strawberry.field(
- description="Information to aid in pagination."
- )
+ if last is not None:
+ edges = edges[-last:]
- edges: list["Edge[GenericType]"] = strawberry.field(
- description="A list of edges in this connection."
- )
+ return cls(
+ edges=edges,
+ page_info=strawberry.relay.PageInfo(
+ start_cursor=edges[0].cursor if edges else None,
+ end_cursor=edges[-1].cursor if edges else None,
+ has_previous_page=(
+ first_edge is not None and bool(edges) and edges[0] != first_edge
+ ),
+ has_next_page=(
+ last_edge is not None and bool(edges) and edges[-1] != last_edge
+ ),
+ ),
+ )
@strawberry.type
-class PageInfo:
- has_next_page: bool = strawberry.field(
- description="When paginating forwards, are there more items?"
- )
+class Query:
+ @connection(UserConnection)
+ def get_users(self) -> Iterable[User]:
+ # This can be a database query, a generator, an async generator, etc
+ return some_function_that_returns_users()
+```
- has_previous_page: bool = strawberry.field(
- description="When paginating backwards, are there more items?"
- )
+This would generate a schema like this:
- start_cursor: Optional[str] = strawberry.field(
- description="When paginating backwards, the cursor to continue."
- )
+```graphql
+type PageInfo {
+ hasNextPage: Boolean!
+ hasPreviousPage: Boolean!
+ startCursor: String
+ endCursor: String
+}
- end_cursor: Optional[str] = strawberry.field(
- description="When paginating forwards, the cursor to continue."
- )
+type User {
+ id: ID!
+ name: String!
+ occupation: String!
+ age: Int!
+}
+type UserEdge {
+ cursor: String!
+ node: User!
+}
-@strawberry.type
-class Edge(Generic[GenericType]):
- node: GenericType = strawberry.field(description="The item at the end of the edge.")
+type UserConnection {
+ pageInfo: PageInfo!
+ edges: [UserEdge!]!
+}
- cursor: str = strawberry.field(description="A cursor for use in pagination.")
+type Query {
+ getUsers(
+ first: Int = null
+ last: Int = null
+ before: String = null
+ after: String = null
+ ): UserConnection!
+}
```
-EdgeTypes must have atleast two fields - `cursor` and `node`. Each edge has it's
-own cursor and item (represented by the `node` field).
+## ListConnection
-Now that we have the types needed to implement pagination using Relay
-Connections, let us use them to paginate a list of users. For simplicity's sake,
-let our dataset be a list of dictionaries.
+Strawberry also provides `ListConnection`, a subclass of `Connection` that
+implementes a limit/offset pagination algorithm by using slices.
-```py line=7-32
-# example.py
+If a limit/offset pagination is enough for your needs, the above example can be
+simplified to use `ListConnection` to a basic resolver that returns one of:
-from typing import Generic, TypeVar
+- `List[]`
+- `Iterator[]`
+- `Iterable[]`
+- `AsyncIterator[]`
+- `AsyncIterable[]`
+- `Generator[, Any, Any]`
+- `AsyncGenerator[, Any]`
-import strawberry
+For example:
-user_data = [
- {
- "id": 1,
- "name": "Norman Osborn",
- "occupation": "Founder, Oscorp Industries",
- "age": 42,
- },
- {
- "id": 2,
- "name": "Peter Parker",
- "occupation": "Freelance Photographer, The Daily Bugle",
- "age": 20,
- },
- {
- "id": 3,
- "name": "Harold Osborn",
- "occupation": "President, Oscorp Industries",
- "age": 19,
- },
- {
- "id": 4,
- "name": "Eddie Brock",
- "occupation": "Journalist, The Eddie Brock Report",
- "age": 20,
- },
-]
-
-
-GenericType = TypeVar("GenericType")
+```python
+import strawberry
+from strawberry.pagination import Connection, Edge, to_base64
@strawberry.type
-class Connection(Generic[GenericType]):
- page_info: "PageInfo" = strawberry.field(
- description="Information to aid in pagination."
- )
+class Query:
+ @connection(ListConnection[User])
+ def get_users(self) -> Iterable[User]:
+ # This can be a database query, a generator, an async generator, etc
+ return some_function_that_returns_users()
+```
- edges: list["Edge[GenericType]"] = strawberry.field(
- description="A list of edges in this connection."
- )
+Because the implementation will use a slice to paginate the data, that means you
+can override what the slice does by customizing the `__getitem__` method of the
+object returned by your nodes resolver.
+For example, when working with `Django`, `resolve_nodes` can return a
+`QuerySet`, meaning that the slice on it will translate to a `LIMIT`/`OFFSET` in
+the SQL query, making it fetch only the data that is needed from the database.
-@strawberry.type
-class PageInfo:
- has_next_page: bool = strawberry.field(
- description="When paginating forwards, are there more items?"
- )
+Also note that if that object doesn't have a `__getitem__` attribute, it will
+use `itertools.islice` to paginate it, meaning that when a generator is being
+resolved it will only generate as much results as needed for the given
+pagination, the worst case scenario being the last results needing to be
+returned.
- has_previous_page: bool = strawberry.field(
- description="When paginating backwards, are there more items?"
- )
+## Custom Connection Arguments
- start_cursor: Optional[str] = strawberry.field(
- description="When paginating backwards, the cursor to continue."
- )
+By default the connection will automatically insert some arguments for it to be
+able to paginate the results. Those are:
- end_cursor: Optional[str] = strawberry.field(
- description="When paginating forwards, the cursor to continue."
- )
+- `before`: Returns the items in the list that come before the specified cursor
+- `after`: Returns the items in the list that come after the " "specified cursor
+- `first`: Returns the first n items from the list
+- `last`: Returns the items in the list that come after the " "specified cursor
+You can still define extra arguments to be used by your own resolver or custom
+pagination logic, and those will be merged together. For example, suppose we
+want to return the pagination of all users whose name starts with a given
+string. We could do that like this:
+```python
@strawberry.type
-class Edge(Generic[GenericType]):
- node: GenericType = strawberry.field(description="The item at the end of the edge.")
-
- cursor: str = strawberry.field(description="A cursor for use in pagination.")
+class Query:
+ @connection(ListConnection[User])
+ def get_users(self, name_starswith: str) -> Iterable[User]:
+ return some_function_that_returns_users(name_startswith=name_startswith)
```
-Now is a good time to think of what we could use as a cursor for our dataset.
-Our cursor needs to be an opaque value, which doesn't usually change over time.
-It makes sense to use base64 encoded IDs of users as our cursor, as they fit
-both criteria.
-
-
+This will generate a `Query` like this:
-While working with Connections, it is a convention to base64-encode cursors. It
-provides a unified interface to the end user. API clients need not bother about
-the type of data to paginate, and can pass unique IDs during pagination. It also
-makes the cursors opaque.
-
-
+```graphql
+type Query {
+ getUsers(
+ nameStartswith: String!
+ first: Int = null
+ last: Int = null
+ before: String = null
+ after: String = null
+ ): UserConnection!
+}
+```
-Let us define a couple of helper functions to encode and decode cursors as
-follows:
+## Converting the node to its proper type when resolving the connection
-```py line=3,35-43
-# example.py
+The connection expects that the resolver will return a list of objects that is a
+subclass of its `NodeType`. But there may be situations where you are resolving
+something that needs to be converted to the proper type, like an ORM model.
-from base64 import b64encode, b64decode
-from typing import Generic, TypeVar
+In this case you can subclass the `Connection`/`ListConnection` and provide a
+custom `resolve_node` method to it, which by default returns the node as is. For
+example:
+```python
import strawberry
+from strawberry.pagination import ListConnection, connection
-user_data = [
- {
- "id": 1,
- "name": "Norman Osborn",
- "occupation": "Founder, Oscorp Industries",
- "age": 42,
- },
- {
- "id": 2,
- "name": "Peter Parker",
- "occupation": "Freelance Photographer, The Daily Bugle",
- "age": 20,
- },
- {
- "id": 3,
- "name": "Harold Osborn",
- "occupation": "President, Oscorp Industries",
- "age": 19,
- },
- {
- "id": 4,
- "name": "Eddie Brock",
- "occupation": "Journalist, The Eddie Brock Report",
- "age": 20,
- },
-]
-
-
-def encode_user_cursor(id: int) -> str:
- """
- Encodes the given user ID into a cursor.
-
- :param id: The user ID to encode.
-
- :return: The encoded cursor.
- """
- return b64encode(f"user:{id}".encode("ascii")).decode("ascii")
-
-
-def decode_user_cursor(cursor: str) -> int:
- """
- Decodes the user ID from the given cursor.
-
- :param cursor: The cursor to decode.
-
- :return: The decoded user ID.
- """
- cursor_data = b64decode(cursor.encode("ascii")).decode("ascii")
- return int(cursor_data.split(":")[1])
-
-
-GenericType = TypeVar("GenericType")
+from db.models import UserModel
@strawberry.type
-class Connection(Generic[GenericType]):
- page_info: "PageInfo" = strawberry.field(
- description="Information to aid in pagination."
- )
-
- edges: list["Edge[GenericType]"] = strawberry.field(
- description="A list of edges in this connection."
- )
+class User:
+ id: int
+ name: str
@strawberry.type
-class PageInfo:
- has_next_page: bool = strawberry.field(
- description="When paginating forwards, are there more items?"
- )
-
- has_previous_page: bool = strawberry.field(
- description="When paginating backwards, are there more items?"
- )
-
- start_cursor: Optional[str] = strawberry.field(
- description="When paginating backwards, the cursor to continue."
- )
-
- end_cursor: Optional[str] = strawberry.field(
- description="When paginating forwards, the cursor to continue."
- )
+class UserConnection(ListConnection[User]):
+ @classmethod
+ def resolve_node(cls, node: UserModel, *, info, **kwargs) -> User:
+ return User(
+ id=node.id,
+ name=node.name,
+ )
@strawberry.type
-class Edge(Generic[GenericType]):
- node: GenericType = strawberry.field(description="The item at the end of the edge.")
-
- cursor: str = strawberry.field(description="A cursor for use in pagination.")
+class Query:
+ @connection(UserConnection)
+ def get_users(self, info: strawberry.Info) -> Iterable[UserDB]:
+ return UserDB.objects.all()
```
-Let us define a `get_users` field which returns a connection of users, as well
-as an `UserType`. Let us also plug our query into a schema.
+The main advantage of this approach instead of converting it inside the custom
+resolver is that the `Connection` will paginate the `QuerySet` first, which in
+case of Django will make sure that only the paginated results are fetched from
+the database. After that, the `resolve_node` function will be called for each
+result to retrieve the correct object for it.
+
+We used Django for this example, but the same applies to any other other similar
+use case, like SQLAlchemy, etc.
-```python line=104-174
-# example.py
+## Full working example
-from base64 import b64encode, b64decode
-from typing import List, Optional, Generic, TypeVar
+Here is a full working example of a connection of users which you can play with:
+```python
+from typing import Iterable
import strawberry
+from strawberry.pagination import ListConnection, connection
user_data = [
{
@@ -445,152 +367,28 @@ user_data = [
]
-def encode_user_cursor(id: int) -> str:
- """
- Encodes the given user ID into a cursor.
-
- :param id: The user ID to encode.
-
- :return: The encoded cursor.
- """
- return b64encode(f"user:{id}".encode("ascii")).decode("ascii")
-
-
-def decode_user_cursor(cursor: str) -> int:
- """
- Decodes the user ID from the given cursor.
-
- :param cursor: The cursor to decode.
-
- :return: The decoded user ID.
- """
- cursor_data = b64decode(cursor.encode("ascii")).decode("ascii")
- return int(cursor_data.split(":")[1])
-
-
-GenericType = TypeVar("GenericType")
-
-
-@strawberry.type
-class Connection(Generic[GenericType]):
- page_info: "PageInfo" = strawberry.field(
- description="Information to aid in pagination."
- )
-
- edges: list["Edge[GenericType]"] = strawberry.field(
- description="A list of edges in this connection."
- )
-
-
-@strawberry.type
-class PageInfo:
- has_next_page: bool = strawberry.field(
- description="When paginating forwards, are there more items?"
- )
-
- has_previous_page: bool = strawberry.field(
- description="When paginating backwards, are there more items?"
- )
-
- start_cursor: Optional[str] = strawberry.field(
- description="When paginating backwards, the cursor to continue."
- )
-
- end_cursor: Optional[str] = strawberry.field(
- description="When paginating forwards, the cursor to continue."
- )
-
-
-@strawberry.type
-class Edge(Generic[GenericType]):
- node: GenericType = strawberry.field(description="The item at the end of the edge.")
-
- cursor: str = strawberry.field(description="A cursor for use in pagination.")
-
-
@strawberry.type
class User:
- id: int = strawberry.field(description="The id of the user.")
-
- name: str = strawberry.field(description="The name of the user.")
-
- occupation: str = strawberry.field(description="The occupation of the user.")
-
- age: int = strawberry.field(description="The age of the user.")
+ id: int
+ name: str
+ occupation: str
+ age: int
@strawberry.type
class Query:
- @strawberry.field(description="Get a list of users.")
- def get_users(
- self, first: int = 2, after: Optional[str] = None
- ) -> Connection[User]:
- if after is not None:
- # decode the user ID from the given cursor.
- user_id = decode_user_cursor(cursor=after)
- else:
- # no cursor was given (this happens usually when the
- # client sends a query for the first time).
- user_id = 0
-
- # filter the user data, going through the next set of results.
- filtered_data = list(filter(lambda user: user["id"] > user_id, user_data))
-
- # slice the relevant user data (Here, we also slice an
- # additional user instance, to prepare the next cursor).
- sliced_users = filtered_data[: first + 1]
-
- if len(sliced_users) > first:
- # calculate the client's next cursor.
- last_user = sliced_users.pop(-1)
- next_cursor = encode_user_cursor(id=last_user["id"])
- has_next_page = True
- else:
- # We have reached the last page, and
- # don't have the next cursor.
- next_cursor = None
- has_next_page = False
-
- # We know that we have items in the
- # previous page window if the initial user ID
- # was not the first one.
- has_previous_page = user_id > 0
-
- # build user edges.
- edges = [
- Edge(
- node=User(**user),
- cursor=encode_user_cursor(id=user["id"]),
+ @connection(ListConnection[User])
+ def get_users(self) -> Iterable[User]:
+ return [
+ User(
+ id=user["id"],
+ name=user["name"],
+ occupation=user["occupation"],
+ age=user["age"],
)
- for user in sliced_users
+ for user in user_data
]
- if edges:
- # we have atleast one edge. Get the cursor
- # of the first edge we have.
- start_cursor = edges[0].cursor
- else:
- # We have no edges to work with.
- start_cursor = None
-
- if len(edges) > 1:
- # We have atleast 2 edges. Get the cursor
- # of the last edge we have.
- end_cursor = edges[-1].cursor
- else:
- # We don't have enough edges to work with.
- end_cursor = None
-
- return Connection(
- edges=edges,
- page_info=PageInfo(
- has_next_page=has_next_page,
- has_previous_page=has_previous_page,
- start_cursor=start_cursor,
- end_cursor=end_cursor,
- ),
- )
-
schema = strawberry.Schema(query=Query)
```
diff --git a/docs/guides/pagination/overview.md b/docs/guides/pagination/overview.md
index 8b309575f8..9938436415 100644
--- a/docs/guides/pagination/overview.md
+++ b/docs/guides/pagination/overview.md
@@ -136,8 +136,8 @@ consistent way to handle pagination.
Strawberry provides a cursor based pagination implementing the
-[relay spec](https://relay.dev/docs/guides/graphql-server-specification/). You
-can read more about it in the [relay](../relay) page.
+[connection](https://relay.dev/docs/guides/graphql-server-specification/). You
+can read more about it in the [connections](./connections) page.
@@ -203,4 +203,4 @@ Let us look at how we can implement pagination in GraphQL.
- [Implementing Offset Pagination](./offset-based.md)
- [Implementing Cursor Pagination](./cursor-based.md)
-- [Implementing the Relay Connection Specification](./connections.md)
+- [Implementing the Connection Specification](./connections.md)
diff --git a/docs/guides/relay.md b/docs/guides/relay.md
index 394583abcf..2b78cea8da 100644
--- a/docs/guides/relay.md
+++ b/docs/guides/relay.md
@@ -82,14 +82,15 @@ string `:`. In the example above, the `Fruit` with a code of
-Now we can expose it in the schema for retrieval and pagination like:
+Now we can expose it in the schema for retrieval and
+[pagination](./pagination/connections) like:
```python
@strawberry.type
class Query:
node: relay.Node = relay.node()
- @relay.connection(relay.ListConnection[Fruit])
+ @connection(ListConnection[Fruit])
def fruits(self) -> Iterable[Fruit]:
# This can be a database query, a generator, an async generator, etc
return all_fruits.values()
@@ -179,16 +180,6 @@ query {
}
```
-The connection resolver for `relay.ListConnection` should return one of those:
-
-- `List[]`
-- `Iterator[]`
-- `Iterable[]`
-- `AsyncIterator[]`
-- `AsyncIterable[]`
-- `Generator[, Any, Any]`
-- `AsyncGenerator[, Any]`
-
### The node field
As demonstrated above, the `Node` field can be used to retrieve/refetch any
@@ -205,207 +196,14 @@ It can be defined in the `Query` objects in 4 ways:
- `node: List[Optional[Node]]`: The same as `List[Node]`, but the returned list
can contain `null` values if the given objects don't exist.
-### Custom connection pagination
-
-The default `relay.Connection` class don't implement any pagination logic, and
-should be used as a base class to implement your own pagination logic. All you
-need to do is implement the `resolve_connection` classmethod.
-
-The integration provides `relay.ListConnection`, which implements a limit/offset
-approach to paginate the results. This is a basic approach and might be enough
-for most use cases.
-
-
-
-`relay.ListConnection` implementes the limit/offset by using slices. That means
-that you can override what the slice does by customizing the `__getitem__`
-method of the object returned by your nodes resolver.
-
-For example, when working with `Django`, `resolve_nodes` can return a
-`QuerySet`, meaning that the slice on it will translate to a `LIMIT`/`OFFSET` in
-the SQL query, which will fetch only the data that is needed from the database.
-
-Also note that if that object doesn't have a `__getitem__` attribute, it will
-use `itertools.islice` to paginate it, meaning that when a generator is being
-resolved it will only generate as much results as needed for the given
-pagination, the worst case scenario being the last results needing to be
-returned.
-
-
-
-Now, suppose we want to implement a custom cursor-based pagination for our
-previous example. We can do something like this:
-
-```python
-import strawberry
-from strawberry import relay
-
-
-@strawberry.type
-class FruitCustomPaginationConnection(relay.Connection[Fruit]):
- @classmethod
- def resolve_connection(
- cls,
- nodes: Iterable[Fruit],
- *,
- info: Optional[Info] = None,
- before: Optional[str] = None,
- after: Optional[str] = None,
- first: Optional[int] = None,
- last: Optional[int] = None,
- ):
- # NOTE: This is a showcase implementation and is far from
- # being optimal performance wise
- edges_mapping = {
- relay.to_base64("fruit_name", n.name): relay.Edge(
- node=n,
- cursor=relay.to_base64("fruit_name", n.name),
- )
- for n in sorted(nodes, key=lambda f: f.name)
- }
- edges = list(edges_mapping.values())
- first_edge = edges[0] if edges else None
- last_edge = edges[-1] if edges else None
-
- if after is not None:
- after_edge_idx = edges.index(edges_mapping[after])
- edges = [e for e in edges if edges.index(e) > after_edge_idx]
-
- if before is not None:
- before_edge_idx = edges.index(edges_mapping[before])
- edges = [e for e in edges if edges.index(e) < before_edge_idx]
-
- if first is not None:
- edges = edges[:first]
-
- if last is not None:
- edges = edges[-last:]
-
- return cls(
- edges=edges,
- page_info=strawberry.relay.PageInfo(
- start_cursor=edges[0].cursor if edges else None,
- end_cursor=edges[-1].cursor if edges else None,
- has_previous_page=(
- first_edge is not None and bool(edges) and edges[0] != first_edge
- ),
- has_next_page=(
- last_edge is not None and bool(edges) and edges[-1] != last_edge
- ),
- ),
- )
-
-
-@strawberry.type
-class Query:
- @relay.connection(FruitCustomPaginationConnection)
- def fruits(self) -> Iterable[Fruit]:
- # This can be a database query, a generator, an async generator, etc
- return all_fruits.values()
-```
-
-In the example above we specialized the `FruitCustomPaginationConnection` by
-inheriting it from `relay.Connection[Fruit]`. We could still keep it generic by
-inheriting it from `relay.Connection[relay.NodeType]` and then specialize it
-when defining the field, making it possible to use our custom pagination logic
-with more than one type.
-
-### Custom connection arguments
-
-By default the connection will automatically insert some arguments for it to be
-able to paginate the results. Those are:
-
-- `before`: Returns the items in the list that come before the specified cursor
-- `after`: Returns the items in the list that come after the " "specified cursor
-- `first`: Returns the first n items from the list
-- `last`: Returns the items in the list that come after the " "specified cursor
-
-You can still define extra arguments to be used by your own resolver or custom
-pagination logic. For example, suppose we want to return the pagination of all
-fruits whose name starts with a given string. We could do that like this:
-
-```python
-@strawberry.type
-class Query:
- @relay.connection(relay.ListConnection[Fruit])
- def fruits_with_filter(
- self,
- info: strawberry.Info,
- name_endswith: str,
- ) -> Iterable[Fruit]:
- for f in fruits.values():
- if f.name.endswith(name_endswith):
- yield f
-```
-
-This will generate a schema like this:
-
-```graphql
-type Query {
- fruitsWithFilter(
- nameEndswith: String!
- before: String = null
- after: String = null
- first: Int = null
- last: Int = null
- ): FruitConnection!
-}
-```
-
-### Convert the node to its proper type when resolving the connection
-
-The connection expects that the resolver will return a list of objects that is a
-subclass of its `NodeType`. But there may be situations where you are resolving
-something that needs to be converted to the proper type, like an ORM model.
-
-In this case you can subclass the `relay.Connection`/`relay.ListConnection` and
-provide a custom `resolve_node` method to it, which by default returns the node
-as is. For example:
-
-```python
-import strawberry
-from strawberry import relay
-
-from db import models
-
-
-@strawberry.type
-class Fruit(relay.Node):
- code: relay.NodeID[int]
- name: str
- weight: float
-
-
-@strawberry.type
-class FruitDBConnection(relay.ListConnection[Fruit]):
- @classmethod
- def resolve_node(cls, node: FruitDB, *, info: strawberry.Info, **kwargs) -> Fruit:
- return Fruit(
- code=node.code,
- name=node.name,
- weight=node.weight,
- )
-
-
-@strawberry.type
-class Query:
- @relay.connection(FruitDBConnection)
- def fruits_with_filter(
- self,
- info: strawberry.Info,
- name_endswith: str,
- ) -> Iterable[models.Fruit]:
- return models.Fruit.objects.filter(name__endswith=name_endswith)
-```
+### The connection field
-The main advantage of this approach instead of converting it inside the custom
-resolver is that the `Connection` will paginate the `QuerySet` first, which in
-case of django will make sure that only the paginated results are fetched from
-the database. After that, the `resolve_node` function will be called for each
-result to retrieve the correct object for it.
+The connection field, although defined in the relay specifications, is actually
+a well-known pattern in GraphQL for getting paginated results.
-We used django for this example, but the same applies to any other other similar
-use case, like SQLAlchemy, etc.
+Because of that, it is implemented in Strawberry as a generic pagination
+solution, which you can read more about in the
+[pagination](./pagination/connections) page.
### The GlobalID scalar
diff --git a/strawberry/pagination/__init__.py b/strawberry/pagination/__init__.py
new file mode 100644
index 0000000000..cf7b142207
--- /dev/null
+++ b/strawberry/pagination/__init__.py
@@ -0,0 +1,21 @@
+from .fields import ConnectionExtension, connection
+from .types import (
+ Connection,
+ Edge,
+ ListConnection,
+ NodeType,
+ PageInfo,
+)
+from .utils import from_base64, to_base64
+
+__all__ = [
+ "Connection",
+ "ConnectionExtension",
+ "Edge",
+ "ListConnection",
+ "NodeType",
+ "PageInfo",
+ "connection",
+ "from_base64",
+ "to_base64",
+]
diff --git a/strawberry/pagination/exceptions.py b/strawberry/pagination/exceptions.py
new file mode 100644
index 0000000000..6e9c0bcc74
--- /dev/null
+++ b/strawberry/pagination/exceptions.py
@@ -0,0 +1,82 @@
+from __future__ import annotations
+
+from collections.abc import Callable
+from functools import cached_property
+from typing import TYPE_CHECKING, Optional, Type, cast
+
+from strawberry.exceptions.exception import StrawberryException
+from strawberry.exceptions.utils.source_finder import SourceFinder
+
+if TYPE_CHECKING:
+ from strawberry.exceptions.exception_source import ExceptionSource
+ from strawberry.types.fields.resolver import StrawberryResolver
+
+
+class ConnectionWrongAnnotationError(StrawberryException):
+ def __init__(self, field_name: str, cls: Type) -> None:
+ self.cls = cls
+ self.field_name = field_name
+
+ self.message = (
+ f'Wrong annotation used on field "{field_name}". It should be '
+ 'annotated with a "Connection" subclass.'
+ )
+ self.rich_message = (
+ f"Wrong annotation for field `[underline]{self.field_name}[/]`"
+ )
+ self.suggestion = (
+ "To fix this error you can add a valid annotation, "
+ f"like [italic]`{self.field_name}: Connection[{cls}]` "
+ f"or [italic]`@connection(Connection[{cls}])`"
+ )
+ self.annotation_message = "connection wrong annotation"
+
+ super().__init__(self.message)
+
+ @cached_property
+ def exception_source(self) -> Optional[ExceptionSource]:
+ if self.cls is None:
+ return None # pragma: no cover
+
+ source_finder = SourceFinder()
+ return source_finder.find_class_attribute_from_object(self.cls, self.field_name)
+
+
+class ConnectionWrongResolverAnnotationError(StrawberryException):
+ def __init__(self, field_name: str, resolver: StrawberryResolver) -> None:
+ self.function = resolver.wrapped_func
+ self.field_name = field_name
+
+ self.message = (
+ f'Wrong annotation used on "{field_name}" resolver. '
+ "It should be return an iterable or async iterable object."
+ )
+ self.rich_message = (
+ f"Wrong annotation used on `{field_name}` resolver. "
+ "It should be return an `iterable` or `async iterable` object."
+ )
+ self.suggestion = (
+ "To fix this error you can annootate your resolver to return "
+ "one of the following options: `List[]`, "
+ "`Iterator[]`, `Iterable[]`, "
+ "`AsyncIterator[]`, `AsyncIterable[]`, "
+ "`Generator[, Any, Any]` and "
+ "`AsyncGenerator[, Any]`."
+ )
+ self.annotation_message = "connection wrong resolver annotation"
+
+ super().__init__(self.message)
+
+ @cached_property
+ def exception_source(self) -> Optional[ExceptionSource]:
+ if self.function is None:
+ return None # pragma: no cover
+
+ source_finder = SourceFinder()
+ return source_finder.find_function_from_object(cast(Callable, self.function))
+
+
+__all__ = [
+ "ConnectionWrongAnnotationError",
+ "ConnectionWrongResolverAnnotationError",
+]
diff --git a/strawberry/pagination/fields.py b/strawberry/pagination/fields.py
new file mode 100644
index 0000000000..3e58c6a0a3
--- /dev/null
+++ b/strawberry/pagination/fields.py
@@ -0,0 +1,338 @@
+from __future__ import annotations
+
+import dataclasses
+import inspect
+from collections.abc import AsyncIterable
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ AsyncIterator,
+ Callable,
+ ForwardRef,
+ Iterable,
+ Iterator,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Type,
+ Union,
+ cast,
+ overload,
+)
+from typing_extensions import get_origin
+
+from strawberry.annotation import StrawberryAnnotation
+from strawberry.extensions.field_extension import (
+ AsyncExtensionResolver,
+ FieldExtension,
+ SyncExtensionResolver,
+)
+from strawberry.types.arguments import StrawberryArgument
+from strawberry.types.field import _RESOLVER_TYPE, StrawberryField
+from strawberry.types.lazy_type import LazyType
+from strawberry.utils.typing import eval_type, is_generic_alias
+
+from .exceptions import (
+ ConnectionWrongAnnotationError,
+ ConnectionWrongResolverAnnotationError,
+)
+from .types import Connection, NodeIterableType, NodeType
+
+if TYPE_CHECKING:
+ from typing_extensions import Literal
+
+ from strawberry.permission import BasePermission
+ from strawberry.types.info import Info
+
+
+class ConnectionExtension(FieldExtension):
+ connection_type: Type[Connection]
+
+ def apply(self, field: StrawberryField) -> None:
+ field.arguments = [
+ *field.arguments,
+ StrawberryArgument(
+ python_name="before",
+ graphql_name=None,
+ type_annotation=StrawberryAnnotation(Optional[str]),
+ description=(
+ "Returns the items in the list that come before the "
+ "specified cursor."
+ ),
+ default=None,
+ ),
+ StrawberryArgument(
+ python_name="after",
+ graphql_name=None,
+ type_annotation=StrawberryAnnotation(Optional[str]),
+ description=(
+ "Returns the items in the list that come after the "
+ "specified cursor."
+ ),
+ default=None,
+ ),
+ StrawberryArgument(
+ python_name="first",
+ graphql_name=None,
+ type_annotation=StrawberryAnnotation(Optional[int]),
+ description="Returns the first n items from the list.",
+ default=None,
+ ),
+ StrawberryArgument(
+ python_name="last",
+ graphql_name=None,
+ type_annotation=StrawberryAnnotation(Optional[int]),
+ description=(
+ "Returns the items in the list that come after the "
+ "specified cursor."
+ ),
+ default=None,
+ ),
+ ]
+
+ f_type = field.type
+
+ if isinstance(f_type, LazyType):
+ f_type = f_type.resolve_type()
+ field.type = f_type
+
+ type_origin = get_origin(f_type) if is_generic_alias(f_type) else f_type
+ if not isinstance(type_origin, type) or not issubclass(type_origin, Connection):
+ raise ConnectionWrongAnnotationError(field.name, cast(type, field.origin))
+
+ assert field.base_resolver
+ # TODO: We are not using resolver_type.type because it will call
+ # StrawberryAnnotation.resolve, which will strip async types from the
+ # type (i.e. AsyncGenerator[Fruit] will become Fruit). This is done there
+ # for subscription support, but we can't use it here. Maybe we can refactor
+ # this in the future.
+ resolver_type = field.base_resolver.signature.return_annotation
+ if isinstance(resolver_type, str):
+ resolver_type = ForwardRef(resolver_type)
+ if isinstance(resolver_type, ForwardRef):
+ resolver_type = eval_type(
+ resolver_type,
+ field.base_resolver._namespace,
+ None,
+ )
+
+ origin = get_origin(resolver_type)
+ if origin is None or not issubclass(
+ origin, (Iterator, Iterable, AsyncIterator, AsyncIterable)
+ ):
+ raise ConnectionWrongResolverAnnotationError(
+ field.name, field.base_resolver
+ )
+
+ self.connection_type = cast(Type[Connection], field.type)
+
+ def resolve(
+ self,
+ next_: SyncExtensionResolver,
+ source: Any,
+ info: Info,
+ *,
+ before: Optional[str] = None,
+ after: Optional[str] = None,
+ first: Optional[int] = None,
+ last: Optional[int] = None,
+ **kwargs: Any,
+ ) -> Any:
+ assert self.connection_type is not None
+ return self.connection_type.resolve_connection(
+ cast(Iterable, next_(source, info, **kwargs)),
+ info=info,
+ before=before,
+ after=after,
+ first=first,
+ last=last,
+ )
+
+ async def resolve_async(
+ self,
+ next_: AsyncExtensionResolver,
+ source: Any,
+ info: Info,
+ *,
+ before: Optional[str] = None,
+ after: Optional[str] = None,
+ first: Optional[int] = None,
+ last: Optional[int] = None,
+ **kwargs: Any,
+ ) -> Any:
+ assert self.connection_type is not None
+ nodes = next_(source, info, **kwargs)
+ # nodes might be an AsyncIterable/AsyncIterator
+ # In this case we don't await for it
+ if inspect.isawaitable(nodes):
+ nodes = await nodes
+
+ resolved = self.connection_type.resolve_connection(
+ cast(Iterable, nodes),
+ info=info,
+ before=before,
+ after=after,
+ first=first,
+ last=last,
+ )
+
+ # If nodes was an AsyncIterable/AsyncIterator, resolve_connection
+ # will return a coroutine which we need to await
+ if inspect.isawaitable(resolved):
+ resolved = await resolved
+ return resolved
+
+
+@overload
+def connection(
+ graphql_type: Optional[Type[Connection[NodeType]]] = None,
+ *,
+ resolver: Optional[_RESOLVER_TYPE[NodeIterableType[Any]]] = None,
+ name: Optional[str] = None,
+ is_subscription: bool = False,
+ description: Optional[str] = None,
+ init: Literal[True] = True,
+ permission_classes: Optional[List[Type[BasePermission]]] = None,
+ deprecation_reason: Optional[str] = None,
+ default: Any = dataclasses.MISSING,
+ default_factory: Union[Callable[..., object], object] = dataclasses.MISSING,
+ metadata: Optional[Mapping[Any, Any]] = None,
+ directives: Optional[Sequence[object]] = (),
+ extensions: List[FieldExtension] = (), # type: ignore
+) -> Any: ...
+
+
+@overload
+def connection(
+ graphql_type: Optional[Type[Connection[NodeType]]] = None,
+ *,
+ name: Optional[str] = None,
+ is_subscription: bool = False,
+ description: Optional[str] = None,
+ permission_classes: Optional[List[Type[BasePermission]]] = None,
+ deprecation_reason: Optional[str] = None,
+ default: Any = dataclasses.MISSING,
+ default_factory: Union[Callable[..., object], object] = dataclasses.MISSING,
+ metadata: Optional[Mapping[Any, Any]] = None,
+ directives: Optional[Sequence[object]] = (),
+ extensions: List[FieldExtension] = (), # type: ignore
+) -> StrawberryField: ...
+
+
+def connection(
+ graphql_type: Optional[Type[Connection[NodeType]]] = None,
+ *,
+ resolver: Optional[_RESOLVER_TYPE[Any]] = None,
+ name: Optional[str] = None,
+ is_subscription: bool = False,
+ description: Optional[str] = None,
+ permission_classes: Optional[List[Type[BasePermission]]] = None,
+ deprecation_reason: Optional[str] = None,
+ default: Any = dataclasses.MISSING,
+ default_factory: Union[Callable[..., object], object] = dataclasses.MISSING,
+ metadata: Optional[Mapping[Any, Any]] = None,
+ directives: Optional[Sequence[object]] = (),
+ extensions: List[FieldExtension] = (), # type: ignore
+ # This init parameter is used by pyright to determine whether this field
+ # is added in the constructor or not. It is not used to change
+ # any behavior at the moment.
+ init: Literal[True, False, None] = None,
+) -> Any:
+ """Annotate a property or a method to create a connection field.
+
+ Connections are mostly used for pagination purposes. This decorator
+ helps creating a complete connection endpoint that provides default arguments
+ and has a default implementation for the connection slicing.
+
+ Note that when setting a resolver to this field, it is expected for this
+ resolver to return an iterable of the expected node type, not the connection
+ itself. That iterable will then be paginated accordingly. So, the main use
+ case for this is to provide a filtered iterable of nodes by using some custom
+ filter arguments.
+
+ Args:
+ graphql_type: The type of the nodes in the connection. This is used to
+ determine the type of the edges and the node field in the connection.
+ resolver: The resolver for the connection. This is expected to return an
+ iterable of the expected node type.
+ name: The GraphQL name of the field.
+ is_subscription: Whether the field is a subscription.
+ description: The GraphQL description of the field.
+ permission_classes: The permission classes to apply to the field.
+ deprecation_reason: The deprecation reason of the field.
+ default: The default value of the field.
+ default_factory: The default factory of the field.
+ metadata: The metadata of the field.
+ directives: The directives to apply to the field.
+ extensions: The extensions to apply to the field.
+ init: Used only for type checking purposes.
+
+ Examples:
+ Annotating something like this:
+
+ ```python
+ @strawberry.type
+ class X:
+ some_node: Connection[SomeType] = connection(
+ resolver=get_some_nodes,
+ description="ABC",
+ )
+
+ @connection(Connection[SomeType], description="ABC")
+ def get_some_nodes(self, age: int) -> Iterable[SomeType]: ...
+ ```
+
+ Will produce a query like this:
+
+ ```graphql
+ query {
+ someNode (
+ before: String
+ after: String
+ first: String
+ after: String
+ age: Int
+ ) {
+ totalCount
+ pageInfo {
+ hasNextPage
+ hasPreviousPage
+ startCursor
+ endCursor
+ }
+ edges {
+ cursor
+ node {
+ id
+ ...
+ }
+ }
+ }
+ }
+ ```
+
+ .. _Connections:
+ https://relay.dev/graphql/connections.htm
+
+ """
+ f = StrawberryField(
+ python_name=None,
+ graphql_name=name,
+ description=description,
+ type_annotation=StrawberryAnnotation.from_annotation(graphql_type),
+ is_subscription=is_subscription,
+ permission_classes=permission_classes or [],
+ deprecation_reason=deprecation_reason,
+ default=default,
+ default_factory=default_factory,
+ metadata=metadata,
+ directives=directives or (),
+ extensions=[*extensions, ConnectionExtension()],
+ )
+ if resolver is not None:
+ f = f(resolver)
+ return f
+
+
+__all__ = ["ConnectionExtension", "connection"]
diff --git a/strawberry/pagination/types.py b/strawberry/pagination/types.py
new file mode 100644
index 0000000000..3e8c4fec8c
--- /dev/null
+++ b/strawberry/pagination/types.py
@@ -0,0 +1,376 @@
+from __future__ import annotations
+
+import itertools
+import sys
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ AsyncIterable,
+ AsyncIterator,
+ Generic,
+ Iterable,
+ Iterator,
+ List,
+ Optional,
+ Sequence,
+ TypeVar,
+ Union,
+ cast,
+)
+from typing_extensions import Self, TypeAlias
+
+from strawberry.types.base import (
+ StrawberryContainer,
+ get_object_definition,
+)
+from strawberry.types.field import field
+from strawberry.types.info import Info # noqa: TCH001
+from strawberry.types.object_type import type
+from strawberry.utils.aio import aenumerate, aislice
+from strawberry.utils.inspect import in_async_context
+
+from .utils import (
+ SliceMetadata,
+ should_resolve_list_connection_edges,
+ to_base64,
+)
+
+if TYPE_CHECKING:
+ from strawberry.utils.await_maybe import AwaitableOrValue
+
+NodeType = TypeVar("NodeType")
+NodeIterableType: TypeAlias = Union[
+ Iterator[NodeType],
+ Iterable[NodeType],
+ AsyncIterator[NodeType],
+ AsyncIterable[NodeType],
+]
+
+PREFIX = "arrayconnection"
+
+
+@type(description="Information to aid in pagination.")
+class PageInfo:
+ """Information to aid in pagination.
+
+ Attributes:
+ has_next_page:
+ When paginating forwards, are there more items?
+ has_previous_page:
+ When paginating backwards, are there more items?
+ start_cursor:
+ When paginating backwards, the cursor to continue
+ end_cursor:
+ When paginating forwards, the cursor to continue
+ """
+
+ has_next_page: bool = field(
+ description="When paginating forwards, are there more items?",
+ )
+ has_previous_page: bool = field(
+ description="When paginating backwards, are there more items?",
+ )
+ start_cursor: Optional[str] = field(
+ description="When paginating backwards, the cursor to continue.",
+ )
+ end_cursor: Optional[str] = field(
+ description="When paginating forwards, the cursor to continue.",
+ )
+
+
+@type(description="An edge in a connection.")
+class Edge(Generic[NodeType]):
+ """An edge in a connection.
+
+ Attributes:
+ cursor:
+ A cursor for use in pagination
+ node:
+ The item at the end of the edge
+ """
+
+ cursor: str = field(description="A cursor for use in pagination")
+ node: NodeType = field(description="The item at the end of the edge")
+
+ @classmethod
+ def resolve_edge(cls, node: NodeType, *, cursor: Any = None) -> Self:
+ return cls(cursor=to_base64(PREFIX, cursor), node=node)
+
+
+@type(description="A connection to a list of items.")
+class Connection(Generic[NodeType]):
+ """A connection to a list of items.
+
+ Attributes:
+ page_info:
+ Pagination data for this connection
+ edges:
+ Contains the nodes in this connection
+
+ """
+
+ page_info: PageInfo = field(description="Pagination data for this connection")
+ edges: List[Edge[NodeType]] = field(
+ description="Contains the nodes in this connection"
+ )
+
+ @classmethod
+ def resolve_node(cls, node: Any, *, info: Info, **kwargs: Any) -> NodeType:
+ """The identity function for the node.
+
+ This method is used to resolve a node of a different type to the
+ connection's `NodeType`.
+
+ By default it returns the node itself, but subclasses can override
+ this to provide a custom implementation.
+
+ Args:
+ node:
+ The resolved node which should return an instance of this
+ connection's `NodeType`.
+ info:
+ The strawberry execution info resolve the type name from.
+ **kwargs:
+ Additional arguments passed to the resolver.
+
+ """
+ return node
+
+ @classmethod
+ def resolve_connection(
+ cls,
+ nodes: NodeIterableType[NodeType],
+ *,
+ info: Info,
+ before: Optional[str] = None,
+ after: Optional[str] = None,
+ first: Optional[int] = None,
+ last: Optional[int] = None,
+ **kwargs: Any,
+ ) -> AwaitableOrValue[Self]:
+ """Resolve a connection from nodes.
+
+ Subclasses must define this method to paginate nodes based
+ on `first`/`last`/`before`/`after` arguments.
+
+ Args:
+ info: The strawberry execution info resolve the type name from.
+ nodes: An iterable/iteretor of nodes to paginate.
+ before: Returns the items in the list that come before the specified cursor.
+ after: Returns the items in the list that come after the specified cursor.
+ first: Returns the first n items from the list.
+ last: Returns the items in the list that come after the specified cursor.
+ kwargs: Additional arguments passed to the resolver.
+
+ Returns:
+ The resolved `Connection`
+
+ """
+ raise NotImplementedError
+
+
+@type(name="Connection", description="A connection to a list of items.")
+class ListConnection(Connection[NodeType]):
+ """A connection to a list of items.
+
+ Attributes:
+ page_info:
+ Pagination data for this connection
+ edges:
+ Contains the nodes in this connection
+
+ """
+
+ page_info: PageInfo = field(description="Pagination data for this connection")
+ edges: List[Edge[NodeType]] = field(
+ description="Contains the nodes in this connection"
+ )
+
+ @classmethod
+ def resolve_connection(
+ cls,
+ nodes: NodeIterableType[NodeType],
+ *,
+ info: Info,
+ before: Optional[str] = None,
+ after: Optional[str] = None,
+ first: Optional[int] = None,
+ last: Optional[int] = None,
+ **kwargs: Any,
+ ) -> AwaitableOrValue[Self]:
+ """Resolve a connection from the list of nodes.
+
+ This uses the described Connection Pagination algorithm_
+
+ Args:
+ info: The strawberry execution info resolve the type name from.
+ nodes: An iterable/iteretor of nodes to paginate.
+ before: Returns the items in the list that come before the specified cursor.
+ after: Returns the items in the list that come after the specified cursor.
+ first: Returns the first n items from the list.
+ last: Returns the items in the list that come after the specified cursor.
+ kwargs: Additional arguments passed to the resolver.
+
+ Returns:
+ The resolved `Connection`
+
+ .. _Connection Pagination algorithm:
+ https://relay.dev/graphql/connections.htm#sec-Pagination-algorithm
+ """
+ slice_metadata = SliceMetadata.from_arguments(
+ info,
+ before=before,
+ after=after,
+ first=first,
+ last=last,
+ )
+
+ type_def = get_object_definition(cls)
+ assert type_def
+ field_def = type_def.get_field("edges")
+ assert field_def
+
+ field = field_def.resolve_type(type_definition=type_def)
+ while isinstance(field, StrawberryContainer):
+ field = field.of_type
+
+ edge_class = cast(Edge[NodeType], field)
+
+ if isinstance(nodes, (AsyncIterator, AsyncIterable)) and in_async_context():
+
+ async def resolver() -> Self:
+ try:
+ iterator = cast(
+ Union[AsyncIterator[NodeType], AsyncIterable[NodeType]],
+ cast(Sequence, nodes)[
+ slice_metadata.start : slice_metadata.overfetch
+ ],
+ )
+ except TypeError:
+ # TODO: Why mypy isn't narrowing this based on the if above?
+ assert isinstance(nodes, (AsyncIterator, AsyncIterable))
+ iterator = aislice(
+ nodes,
+ slice_metadata.start,
+ slice_metadata.overfetch,
+ )
+
+ # The slice above might return an object that now is not async
+ # iterable anymore (e.g. an already cached django queryset)
+ if isinstance(iterator, (AsyncIterator, AsyncIterable)):
+ edges: List[Edge] = [
+ edge_class.resolve_edge(
+ cls.resolve_node(v, info=info, **kwargs),
+ cursor=slice_metadata.start + i,
+ )
+ async for i, v in aenumerate(iterator)
+ ]
+ else:
+ edges: List[Edge] = [ # type: ignore[no-redef]
+ edge_class.resolve_edge(
+ cls.resolve_node(v, info=info, **kwargs),
+ cursor=slice_metadata.start + i,
+ )
+ for i, v in enumerate(iterator)
+ ]
+
+ has_previous_page = slice_metadata.start > 0
+ if (
+ slice_metadata.expected is not None
+ and len(edges) == slice_metadata.expected + 1
+ ):
+ # Remove the overfetched result
+ edges = edges[:-1]
+ has_next_page = True
+ elif slice_metadata.end == sys.maxsize:
+ # Last was asked without any after/before
+ assert last is not None
+ original_len = len(edges)
+ edges = edges[-last:]
+ has_next_page = False
+ has_previous_page = len(edges) != original_len
+ else:
+ has_next_page = False
+
+ return cls(
+ edges=edges,
+ page_info=PageInfo(
+ start_cursor=edges[0].cursor if edges else None,
+ end_cursor=edges[-1].cursor if edges else None,
+ has_previous_page=has_previous_page,
+ has_next_page=has_next_page,
+ ),
+ )
+
+ return resolver()
+
+ try:
+ iterator = cast(
+ Union[Iterator[NodeType], Iterable[NodeType]],
+ cast(Sequence, nodes)[slice_metadata.start : slice_metadata.overfetch],
+ )
+ except TypeError:
+ assert isinstance(nodes, (Iterable, Iterator))
+ iterator = itertools.islice(
+ nodes,
+ slice_metadata.start,
+ slice_metadata.overfetch,
+ )
+
+ if not should_resolve_list_connection_edges(info):
+ return cls(
+ edges=[],
+ page_info=PageInfo(
+ start_cursor=None,
+ end_cursor=None,
+ has_previous_page=False,
+ has_next_page=False,
+ ),
+ )
+
+ edges = [
+ edge_class.resolve_edge(
+ cls.resolve_node(v, info=info, **kwargs),
+ cursor=slice_metadata.start + i,
+ )
+ for i, v in enumerate(iterator)
+ ]
+
+ has_previous_page = slice_metadata.start > 0
+ if (
+ slice_metadata.expected is not None
+ and len(edges) == slice_metadata.expected + 1
+ ):
+ # Remove the overfetched result
+ edges = edges[:-1]
+ has_next_page = True
+ elif slice_metadata.end == sys.maxsize:
+ # Last was asked without any after/before
+ assert last is not None
+ original_len = len(edges)
+ edges = edges[-last:]
+ has_next_page = False
+ has_previous_page = len(edges) != original_len
+ else:
+ has_next_page = False
+
+ return cls(
+ edges=edges,
+ page_info=PageInfo(
+ start_cursor=edges[0].cursor if edges else None,
+ end_cursor=edges[-1].cursor if edges else None,
+ has_previous_page=has_previous_page,
+ has_next_page=has_next_page,
+ ),
+ )
+
+
+__all__ = [
+ "NodeIterableType",
+ "NodeType",
+ "PREFIX",
+ "Connection",
+ "Edge",
+ "PageInfo",
+ "ListConnection",
+]
diff --git a/strawberry/pagination/utils.py b/strawberry/pagination/utils.py
new file mode 100644
index 0000000000..34af41f8ca
--- /dev/null
+++ b/strawberry/pagination/utils.py
@@ -0,0 +1,198 @@
+from __future__ import annotations
+
+import base64
+import dataclasses
+import sys
+from typing import TYPE_CHECKING, Any, Tuple, Union
+from typing_extensions import Self, assert_never
+
+from strawberry.types.base import StrawberryObjectDefinition
+from strawberry.types.nodes import InlineFragment, Selection
+
+if TYPE_CHECKING:
+ from strawberry.types.info import Info
+
+
+def from_base64(value: str) -> Tuple[str, str]:
+ """Parse the base64 encoded relay value.
+
+ Args:
+ value:
+ The value to be parsed
+
+ Returns:
+ A tuple of (TypeName, NodeID).
+
+ Raises:
+ ValueError:
+ If the value is not in the expected format
+
+ """
+ try:
+ res = base64.b64decode(value.encode()).decode().split(":", 1)
+ except Exception as e:
+ raise ValueError(str(e)) from e
+
+ if len(res) != 2:
+ raise ValueError(f"{res} expected to contain only 2 items")
+
+ return res[0], res[1]
+
+
+def to_base64(type_: Union[str, type, StrawberryObjectDefinition], node_id: Any) -> str:
+ """Encode the type name and node id to a base64 string.
+
+ Args:
+ type_:
+ The GraphQL type, type definition or type name.
+ node_id:
+ The node id itself
+
+ Returns:
+ A GlobalID, which is a string resulting from base64 encoding :.
+
+ Raises:
+ ValueError:
+ If the value is not a valid GraphQL type or name
+
+ """
+ try:
+ if isinstance(type_, str):
+ type_name = type_
+ elif isinstance(type_, StrawberryObjectDefinition):
+ type_name = type_.name
+ elif isinstance(type_, type):
+ type_name = type_.__strawberry_definition__.name # type:ignore
+ else: # pragma: no cover
+ assert_never(type_)
+ except Exception as e:
+ raise ValueError(f"{type_} is not a valid GraphQL type or name") from e
+
+ return base64.b64encode(f"{type_name}:{node_id}".encode()).decode()
+
+
+def should_resolve_list_connection_edges(info: Info) -> bool:
+ """Check if the user requested to resolve the `edges` field of a connection.
+
+ Args:
+ info:
+ The strawberry execution info resolve the type name from
+
+ Returns:
+ True if the user requested to resolve the `edges` field of a connection, False otherwise.
+
+ """
+ resolve_for_field_names = {"edges", "pageInfo"}
+
+ def _check_selection(selection: Selection) -> bool:
+ """Recursively inspect the selection to check if the user requested to resolve the `edges` field.
+
+ Args:
+ selection (Selection): The selection to check.
+
+ Returns:
+ bool: True if the user requested to resolve the `edges` field of a connection, False otherwise.
+ """
+ if (
+ not isinstance(selection, InlineFragment)
+ and selection.name in resolve_for_field_names
+ ):
+ return True
+ if selection.selections:
+ return any(
+ _check_selection(selection) for selection in selection.selections
+ )
+ return False
+
+ for selection_field in info.selected_fields:
+ for selection in selection_field.selections:
+ if _check_selection(selection):
+ return True
+ return False
+
+
+@dataclasses.dataclass
+class SliceMetadata:
+ start: int
+ end: int
+ expected: int | None
+
+ @property
+ def overfetch(self) -> int:
+ # Overfetch by 1 to check if we have a next result
+ return self.end + 1 if self.end != sys.maxsize else self.end
+
+ @classmethod
+ def from_arguments(
+ cls,
+ info: Info,
+ *,
+ before: str | None = None,
+ after: str | None = None,
+ first: int | None = None,
+ last: int | None = None,
+ ) -> Self:
+ """Get the slice metadata to use on ListConnection."""
+ from strawberry.pagination.types import PREFIX
+
+ max_results = info.schema.config.relay_max_results
+ start = 0
+ end: int | None = None
+
+ if after:
+ after_type, after_parsed = from_base64(after)
+ if after_type != PREFIX:
+ raise TypeError("Argument 'after' contains a non-existing value.")
+
+ start = int(after_parsed) + 1
+ if before:
+ before_type, before_parsed = from_base64(before)
+ if before_type != PREFIX:
+ raise TypeError("Argument 'before' contains a non-existing value.")
+ end = int(before_parsed)
+
+ if isinstance(first, int):
+ if first < 0:
+ raise ValueError("Argument 'first' must be a non-negative integer.")
+
+ if first > max_results:
+ raise ValueError(
+ f"Argument 'first' cannot be higher than {max_results}."
+ )
+
+ if end is not None:
+ start = max(0, end - 1)
+
+ end = start + first
+ if isinstance(last, int):
+ if last < 0:
+ raise ValueError("Argument 'last' must be a non-negative integer.")
+
+ if last > max_results:
+ raise ValueError(
+ f"Argument 'last' cannot be higher than {max_results}."
+ )
+
+ if end is not None:
+ start = max(start, end - last)
+ else:
+ end = sys.maxsize
+
+ if end is None:
+ end = start + max_results
+
+ expected = end - start if end != sys.maxsize else None
+
+ return cls(
+ start=start,
+ end=end,
+ expected=expected,
+ )
+
+
+__all__ = [
+ "from_base64",
+ "to_base64",
+ "should_resolve_list_connection_edges",
+ "SliceMetadata",
+]
diff --git a/strawberry/relay/__init__.py b/strawberry/relay/__init__.py
index c41e777b76..7ed73687b9 100644
--- a/strawberry/relay/__init__.py
+++ b/strawberry/relay/__init__.py
@@ -1,31 +1,54 @@
-from .fields import ConnectionExtension, NodeExtension, connection, node
-from .types import (
+import warnings
+from typing import Any
+
+from strawberry.pagination.fields import ConnectionExtension, connection
+from strawberry.pagination.types import (
Connection,
Edge,
+ ListConnection,
+ NodeType,
+ PageInfo,
+)
+from strawberry.pagination.utils import from_base64, to_base64
+
+from .fields import NodeExtension, node
+from .types import (
GlobalID,
GlobalIDValueError,
- ListConnection,
Node,
NodeID,
- NodeType,
- PageInfo,
)
-from .utils import from_base64, to_base64
+
+_DEPRECATIONS = {
+ "Connection": Connection,
+ "ConnectionExtension": ConnectionExtension,
+ "Edge": Edge,
+ "ListConnection": ListConnection,
+ "NodeType": NodeType,
+ "PageInfo": PageInfo,
+ "connection": connection,
+ "from_base64": from_base64,
+ "to_base64": to_base64,
+}
+
+
+def __getattr__(name: str) -> Any:
+ if name in _DEPRECATIONS:
+ warnings.warn(
+ f"{name} should be imported from strawberry.pagination",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return _DEPRECATIONS[name]
+
+ raise AttributeError(f"module {__name__} has no attribute {name}")
+
__all__ = [
- "Connection",
- "ConnectionExtension",
- "Edge",
"GlobalID",
"GlobalIDValueError",
- "ListConnection",
"Node",
"NodeExtension",
"NodeID",
- "NodeType",
- "PageInfo",
- "connection",
- "from_base64",
"node",
- "to_base64",
]
diff --git a/strawberry/relay/exceptions.py b/strawberry/relay/exceptions.py
index 4150633d8c..b8e3d92122 100644
--- a/strawberry/relay/exceptions.py
+++ b/strawberry/relay/exceptions.py
@@ -1,15 +1,13 @@
from __future__ import annotations
-from collections.abc import Callable
from functools import cached_property
-from typing import TYPE_CHECKING, Optional, Type, cast
+from typing import TYPE_CHECKING, Optional, Type
from strawberry.exceptions.exception import StrawberryException
from strawberry.exceptions.utils.source_finder import SourceFinder
if TYPE_CHECKING:
from strawberry.exceptions.exception_source import ExceptionSource
- from strawberry.types.fields.resolver import StrawberryResolver
class NodeIDAnnotationError(StrawberryException):
@@ -40,72 +38,4 @@ def exception_source(self) -> Optional[ExceptionSource]:
return source_finder.find_class_from_object(self.cls)
-class RelayWrongAnnotationError(StrawberryException):
- def __init__(self, field_name: str, cls: Type) -> None:
- self.cls = cls
- self.field_name = field_name
-
- self.message = (
- f'Wrong annotation used on field "{field_name}". It should be '
- 'annotated with a "Connection" subclass.'
- )
- self.rich_message = (
- f"Wrong annotation for field `[underline]{self.field_name}[/]`"
- )
- self.suggestion = (
- "To fix this error you can add a valid annotation, "
- f"like [italic]`{self.field_name}: relay.Connection[{cls}]` "
- f"or [italic]`@relay.connection(relay.Connection[{cls}])`"
- )
- self.annotation_message = "relay wrong annotation"
-
- super().__init__(self.message)
-
- @cached_property
- def exception_source(self) -> Optional[ExceptionSource]:
- if self.cls is None:
- return None # pragma: no cover
-
- source_finder = SourceFinder()
- return source_finder.find_class_attribute_from_object(self.cls, self.field_name)
-
-
-class RelayWrongResolverAnnotationError(StrawberryException):
- def __init__(self, field_name: str, resolver: StrawberryResolver) -> None:
- self.function = resolver.wrapped_func
- self.field_name = field_name
-
- self.message = (
- f'Wrong annotation used on "{field_name}" resolver. '
- "It should be return an iterable or async iterable object."
- )
- self.rich_message = (
- f"Wrong annotation used on `{field_name}` resolver. "
- "It should be return an `iterable` or `async iterable` object."
- )
- self.suggestion = (
- "To fix this error you can annootate your resolver to return "
- "one of the following options: `List[]`, "
- "`Iterator[]`, `Iterable[]`, "
- "`AsyncIterator[]`, `AsyncIterable[]`, "
- "`Generator[, Any, Any]` and "
- "`AsyncGenerator[, Any]`."
- )
- self.annotation_message = "relay wrong resolver annotation"
-
- super().__init__(self.message)
-
- @cached_property
- def exception_source(self) -> Optional[ExceptionSource]:
- if self.function is None:
- return None # pragma: no cover
-
- source_finder = SourceFinder()
- return source_finder.find_function_from_object(cast(Callable, self.function))
-
-
-__all__ = [
- "NodeIDAnnotationError",
- "RelayWrongAnnotationError",
- "RelayWrongResolverAnnotationError",
-]
+__all__ = ["NodeIDAnnotationError"]
diff --git a/strawberry/relay/fields.py b/strawberry/relay/fields.py
index 5af00700f8..13fe48cdb1 100644
--- a/strawberry/relay/fields.py
+++ b/strawberry/relay/fields.py
@@ -1,57 +1,39 @@
from __future__ import annotations
import asyncio
-import dataclasses
import inspect
+import warnings
from collections import defaultdict
-from collections.abc import AsyncIterable
from typing import (
TYPE_CHECKING,
Any,
- AsyncIterator,
Awaitable,
Callable,
DefaultDict,
Dict,
- ForwardRef,
- Iterable,
Iterator,
List,
- Mapping,
- Optional,
- Sequence,
Tuple,
Type,
Union,
cast,
- overload,
)
-from typing_extensions import Annotated, get_origin
+from typing_extensions import Annotated
-from strawberry.annotation import StrawberryAnnotation
from strawberry.extensions.field_extension import (
- AsyncExtensionResolver,
FieldExtension,
SyncExtensionResolver,
)
-from strawberry.relay.exceptions import (
- RelayWrongAnnotationError,
- RelayWrongResolverAnnotationError,
-)
-from strawberry.types.arguments import StrawberryArgument, argument
+from strawberry.pagination.fields import ConnectionExtension, connection
+from strawberry.types.arguments import argument # noqa: TCH001
from strawberry.types.base import StrawberryList, StrawberryOptional
-from strawberry.types.field import _RESOLVER_TYPE, StrawberryField, field
+from strawberry.types.field import StrawberryField, field
from strawberry.types.fields.resolver import StrawberryResolver
-from strawberry.types.lazy_type import LazyType
from strawberry.utils.aio import asyncgen_to_list
-from strawberry.utils.typing import eval_type, is_generic_alias
-from .types import Connection, GlobalID, Node, NodeIterableType, NodeType
+from .types import GlobalID, Node
if TYPE_CHECKING:
- from typing_extensions import Literal
-
- from strawberry.permission import BasePermission
from strawberry.types.info import Info
@@ -182,142 +164,6 @@ async def resolve(resolved: Any = resolved_nodes) -> List[Node]:
return resolver
-class ConnectionExtension(FieldExtension):
- connection_type: Type[Connection[Node]]
-
- def apply(self, field: StrawberryField) -> None:
- field.arguments = [
- *field.arguments,
- StrawberryArgument(
- python_name="before",
- graphql_name=None,
- type_annotation=StrawberryAnnotation(Optional[str]),
- description=(
- "Returns the items in the list that come before the "
- "specified cursor."
- ),
- default=None,
- ),
- StrawberryArgument(
- python_name="after",
- graphql_name=None,
- type_annotation=StrawberryAnnotation(Optional[str]),
- description=(
- "Returns the items in the list that come after the "
- "specified cursor."
- ),
- default=None,
- ),
- StrawberryArgument(
- python_name="first",
- graphql_name=None,
- type_annotation=StrawberryAnnotation(Optional[int]),
- description="Returns the first n items from the list.",
- default=None,
- ),
- StrawberryArgument(
- python_name="last",
- graphql_name=None,
- type_annotation=StrawberryAnnotation(Optional[int]),
- description=(
- "Returns the items in the list that come after the "
- "specified cursor."
- ),
- default=None,
- ),
- ]
-
- f_type = field.type
-
- if isinstance(f_type, LazyType):
- f_type = f_type.resolve_type()
- field.type = f_type
-
- type_origin = get_origin(f_type) if is_generic_alias(f_type) else f_type
- if not isinstance(type_origin, type) or not issubclass(type_origin, Connection):
- raise RelayWrongAnnotationError(field.name, cast(type, field.origin))
-
- assert field.base_resolver
- # TODO: We are not using resolver_type.type because it will call
- # StrawberryAnnotation.resolve, which will strip async types from the
- # type (i.e. AsyncGenerator[Fruit] will become Fruit). This is done there
- # for subscription support, but we can't use it here. Maybe we can refactor
- # this in the future.
- resolver_type = field.base_resolver.signature.return_annotation
- if isinstance(resolver_type, str):
- resolver_type = ForwardRef(resolver_type)
- if isinstance(resolver_type, ForwardRef):
- resolver_type = eval_type(
- resolver_type,
- field.base_resolver._namespace,
- None,
- )
-
- origin = get_origin(resolver_type)
- if origin is None or not issubclass(
- origin, (Iterator, Iterable, AsyncIterator, AsyncIterable)
- ):
- raise RelayWrongResolverAnnotationError(field.name, field.base_resolver)
-
- self.connection_type = cast(Type[Connection[Node]], field.type)
-
- def resolve(
- self,
- next_: SyncExtensionResolver,
- source: Any,
- info: Info,
- *,
- before: Optional[str] = None,
- after: Optional[str] = None,
- first: Optional[int] = None,
- last: Optional[int] = None,
- **kwargs: Any,
- ) -> Any:
- assert self.connection_type is not None
- return self.connection_type.resolve_connection(
- cast(Iterable[Node], next_(source, info, **kwargs)),
- info=info,
- before=before,
- after=after,
- first=first,
- last=last,
- )
-
- async def resolve_async(
- self,
- next_: AsyncExtensionResolver,
- source: Any,
- info: Info,
- *,
- before: Optional[str] = None,
- after: Optional[str] = None,
- first: Optional[int] = None,
- last: Optional[int] = None,
- **kwargs: Any,
- ) -> Any:
- assert self.connection_type is not None
- nodes = next_(source, info, **kwargs)
- # nodes might be an AsyncIterable/AsyncIterator
- # In this case we don't await for it
- if inspect.isawaitable(nodes):
- nodes = await nodes
-
- resolved = self.connection_type.resolve_connection(
- cast(Iterable[Node], nodes),
- info=info,
- before=before,
- after=after,
- first=first,
- last=last,
- )
-
- # If nodes was an AsyncIterable/AsyncIterator, resolve_connection
- # will return a coroutine which we need to await
- if inspect.isawaitable(resolved):
- resolved = await resolved
- return resolved
-
-
if TYPE_CHECKING:
node = field
else:
@@ -327,155 +173,22 @@ def node(*args: Any, **kwargs: Any) -> StrawberryField:
return field(*args, **kwargs)
-@overload
-def connection(
- graphql_type: Optional[Type[Connection[NodeType]]] = None,
- *,
- resolver: Optional[_RESOLVER_TYPE[NodeIterableType[Any]]] = None,
- name: Optional[str] = None,
- is_subscription: bool = False,
- description: Optional[str] = None,
- init: Literal[True] = True,
- permission_classes: Optional[List[Type[BasePermission]]] = None,
- deprecation_reason: Optional[str] = None,
- default: Any = dataclasses.MISSING,
- default_factory: Union[Callable[..., object], object] = dataclasses.MISSING,
- metadata: Optional[Mapping[Any, Any]] = None,
- directives: Optional[Sequence[object]] = (),
- extensions: List[FieldExtension] = (), # type: ignore
-) -> Any: ...
-
-
-@overload
-def connection(
- graphql_type: Optional[Type[Connection[NodeType]]] = None,
- *,
- name: Optional[str] = None,
- is_subscription: bool = False,
- description: Optional[str] = None,
- permission_classes: Optional[List[Type[BasePermission]]] = None,
- deprecation_reason: Optional[str] = None,
- default: Any = dataclasses.MISSING,
- default_factory: Union[Callable[..., object], object] = dataclasses.MISSING,
- metadata: Optional[Mapping[Any, Any]] = None,
- directives: Optional[Sequence[object]] = (),
- extensions: List[FieldExtension] = (), # type: ignore
-) -> StrawberryField: ...
-
-
-def connection(
- graphql_type: Optional[Type[Connection[NodeType]]] = None,
- *,
- resolver: Optional[_RESOLVER_TYPE[Any]] = None,
- name: Optional[str] = None,
- is_subscription: bool = False,
- description: Optional[str] = None,
- permission_classes: Optional[List[Type[BasePermission]]] = None,
- deprecation_reason: Optional[str] = None,
- default: Any = dataclasses.MISSING,
- default_factory: Union[Callable[..., object], object] = dataclasses.MISSING,
- metadata: Optional[Mapping[Any, Any]] = None,
- directives: Optional[Sequence[object]] = (),
- extensions: List[FieldExtension] = (), # type: ignore
- # This init parameter is used by pyright to determine whether this field
- # is added in the constructor or not. It is not used to change
- # any behavior at the moment.
- init: Literal[True, False, None] = None,
-) -> Any:
- """Annotate a property or a method to create a relay connection field.
-
- Relay connections are mostly used for pagination purposes. This decorator
- helps creating a complete relay endpoint that provides default arguments
- and has a default implementation for the connection slicing.
-
- Note that when setting a resolver to this field, it is expected for this
- resolver to return an iterable of the expected node type, not the connection
- itself. That iterable will then be paginated accordingly. So, the main use
- case for this is to provide a filtered iterable of nodes by using some custom
- filter arguments.
-
- Args:
- graphql_type: The type of the nodes in the connection. This is used to
- determine the type of the edges and the node field in the connection.
- resolver: The resolver for the connection. This is expected to return an
- iterable of the expected node type.
- name: The GraphQL name of the field.
- is_subscription: Whether the field is a subscription.
- description: The GraphQL description of the field.
- permission_classes: The permission classes to apply to the field.
- deprecation_reason: The deprecation reason of the field.
- default: The default value of the field.
- default_factory: The default factory of the field.
- metadata: The metadata of the field.
- directives: The directives to apply to the field.
- extensions: The extensions to apply to the field.
- init: Used only for type checking purposes.
-
- Examples:
- Annotating something like this:
-
- ```python
- @strawberry.type
- class X:
- some_node: relay.Connection[SomeType] = relay.connection(
- resolver=get_some_nodes,
- description="ABC",
+_DEPRECATIONS = {
+ "ConnectionField": ConnectionExtension,
+ "connection": connection,
+}
+
+
+def __getattr__(name: str) -> Any:
+ if name in _DEPRECATIONS:
+ warnings.warn(
+ f"{name} should be imported from strawberry.pagination.fields",
+ DeprecationWarning,
+ stacklevel=2,
)
+ return _DEPRECATIONS[name]
- @relay.connection(relay.Connection[SomeType], description="ABC")
- def get_some_nodes(self, age: int) -> Iterable[SomeType]: ...
- ```
-
- Will produce a query like this:
-
- ```graphql
- query {
- someNode (
- before: String
- after: String
- first: String
- after: String
- age: Int
- ) {
- totalCount
- pageInfo {
- hasNextPage
- hasPreviousPage
- startCursor
- endCursor
- }
- edges {
- cursor
- node {
- id
- ...
- }
- }
- }
- }
- ```
-
- .. _Relay connections:
- https://relay.dev/graphql/connections.htm
-
- """
- f = StrawberryField(
- python_name=None,
- graphql_name=name,
- description=description,
- type_annotation=StrawberryAnnotation.from_annotation(graphql_type),
- is_subscription=is_subscription,
- permission_classes=permission_classes or [],
- deprecation_reason=deprecation_reason,
- default=default,
- default_factory=default_factory,
- metadata=metadata,
- directives=directives or (),
- extensions=[*extensions, ConnectionExtension()],
- )
- if resolver is not None:
- f = f(resolver)
- return f
-
-
-__all__ = ["node", "connection"]
+ raise AttributeError(f"module {__name__} has no attribute {name}")
+
+
+__all__ = ["NodeExtension", "node"]
diff --git a/strawberry/relay/types.py b/strawberry/relay/types.py
index 869e017be9..29e75fb5b2 100644
--- a/strawberry/relay/types.py
+++ b/strawberry/relay/types.py
@@ -2,22 +2,16 @@
import dataclasses
import inspect
-import itertools
import sys
+import warnings
from typing import (
TYPE_CHECKING,
Any,
- AsyncIterable,
- AsyncIterator,
Awaitable,
ClassVar,
ForwardRef,
- Generic,
Iterable,
- Iterator,
- List,
Optional,
- Sequence,
Type,
TypeVar,
Union,
@@ -26,44 +20,37 @@
)
from typing_extensions import Annotated, Literal, Self, TypeAlias, get_args, get_origin
+from strawberry.pagination.types import (
+ PREFIX,
+ Connection,
+ Edge,
+ ListConnection,
+ NodeIterableType,
+ NodeType,
+ PageInfo,
+)
+from strawberry.pagination.utils import (
+ from_base64,
+ to_base64,
+)
from strawberry.relay.exceptions import NodeIDAnnotationError
from strawberry.types.base import (
- StrawberryContainer,
StrawberryObjectDefinition,
- get_object_definition,
)
from strawberry.types.field import field
from strawberry.types.info import Info # noqa: TCH001
from strawberry.types.lazy_type import LazyType
-from strawberry.types.object_type import interface, type
+from strawberry.types.object_type import interface
from strawberry.types.private import StrawberryPrivate
-from strawberry.utils.aio import aenumerate, aislice, resolve_awaitable
-from strawberry.utils.inspect import in_async_context
+from strawberry.utils.aio import resolve_awaitable
from strawberry.utils.typing import eval_type, is_classvar
-from .utils import (
- SliceMetadata,
- from_base64,
- should_resolve_list_connection_edges,
- to_base64,
-)
-
if TYPE_CHECKING:
from strawberry.scalars import ID
from strawberry.utils.await_maybe import AwaitableOrValue
_T = TypeVar("_T")
-NodeIterableType: TypeAlias = Union[
- Iterator[_T],
- Iterable[_T],
- AsyncIterator[_T],
- AsyncIterable[_T],
-]
-NodeType = TypeVar("NodeType", bound="Node")
-
-PREFIX = "arrayconnection"
-
class GlobalIDValueError(ValueError):
"""GlobalID value error, usually related to parsing or serialization."""
@@ -618,320 +605,27 @@ def resolve_node(
return next(iter(cast(Iterable[Self], retval)))
-@type(description="Information to aid in pagination.")
-class PageInfo:
- """Information to aid in pagination.
-
- Attributes:
- has_next_page:
- When paginating forwards, are there more items?
- has_previous_page:
- When paginating backwards, are there more items?
- start_cursor:
- When paginating backwards, the cursor to continue
- end_cursor:
- When paginating forwards, the cursor to continue
- """
-
- has_next_page: bool = field(
- description="When paginating forwards, are there more items?",
- )
- has_previous_page: bool = field(
- description="When paginating backwards, are there more items?",
- )
- start_cursor: Optional[str] = field(
- description="When paginating backwards, the cursor to continue.",
- )
- end_cursor: Optional[str] = field(
- description="When paginating forwards, the cursor to continue.",
- )
-
-
-@type(description="An edge in a connection.")
-class Edge(Generic[NodeType]):
- """An edge in a connection.
-
- Attributes:
- cursor:
- A cursor for use in pagination
- node:
- The item at the end of the edge
- """
-
- cursor: str = field(description="A cursor for use in pagination")
- node: NodeType = field(description="The item at the end of the edge")
-
- @classmethod
- def resolve_edge(cls, node: NodeType, *, cursor: Any = None) -> Self:
- return cls(cursor=to_base64(PREFIX, cursor), node=node)
-
-
-@type(description="A connection to a list of items.")
-class Connection(Generic[NodeType]):
- """A connection to a list of items.
-
- Attributes:
- page_info:
- Pagination data for this connection
- edges:
- Contains the nodes in this connection
-
- """
-
- page_info: PageInfo = field(description="Pagination data for this connection")
- edges: List[Edge[NodeType]] = field(
- description="Contains the nodes in this connection"
- )
-
- @classmethod
- def resolve_node(cls, node: Any, *, info: Info, **kwargs: Any) -> NodeType:
- """The identity function for the node.
-
- This method is used to resolve a node of a different type to the
- connection's `NodeType`.
-
- By default it returns the node itself, but subclasses can override
- this to provide a custom implementation.
-
- Args:
- node:
- The resolved node which should return an instance of this
- connection's `NodeType`.
- info:
- The strawberry execution info resolve the type name from.
- **kwargs:
- Additional arguments passed to the resolver.
-
- """
- return node
-
- @classmethod
- def resolve_connection(
- cls,
- nodes: NodeIterableType[NodeType],
- *,
- info: Info,
- before: Optional[str] = None,
- after: Optional[str] = None,
- first: Optional[int] = None,
- last: Optional[int] = None,
- **kwargs: Any,
- ) -> AwaitableOrValue[Self]:
- """Resolve a connection from nodes.
-
- Subclasses must define this method to paginate nodes based
- on `first`/`last`/`before`/`after` arguments.
-
- Args:
- info: The strawberry execution info resolve the type name from.
- nodes: An iterable/iteretor of nodes to paginate.
- before: Returns the items in the list that come before the specified cursor.
- after: Returns the items in the list that come after the specified cursor.
- first: Returns the first n items from the list.
- last: Returns the items in the list that come after the specified cursor.
- kwargs: Additional arguments passed to the resolver.
-
- Returns:
- The resolved `Connection`
-
- """
- raise NotImplementedError
-
-
-@type(name="Connection", description="A connection to a list of items.")
-class ListConnection(Connection[NodeType]):
- """A connection to a list of items.
-
- Attributes:
- page_info:
- Pagination data for this connection
- edges:
- Contains the nodes in this connection
-
- """
-
- page_info: PageInfo = field(description="Pagination data for this connection")
- edges: List[Edge[NodeType]] = field(
- description="Contains the nodes in this connection"
- )
-
- @classmethod
- def resolve_connection(
- cls,
- nodes: NodeIterableType[NodeType],
- *,
- info: Info,
- before: Optional[str] = None,
- after: Optional[str] = None,
- first: Optional[int] = None,
- last: Optional[int] = None,
- **kwargs: Any,
- ) -> AwaitableOrValue[Self]:
- """Resolve a connection from the list of nodes.
-
- This uses the described Relay Pagination algorithm_
-
- Args:
- info: The strawberry execution info resolve the type name from.
- nodes: An iterable/iteretor of nodes to paginate.
- before: Returns the items in the list that come before the specified cursor.
- after: Returns the items in the list that come after the specified cursor.
- first: Returns the first n items from the list.
- last: Returns the items in the list that come after the specified cursor.
- kwargs: Additional arguments passed to the resolver.
+_DEPRECATIONS = {
+ "Connection": Connection,
+ "Edge": Edge,
+ "ListConnection": ListConnection,
+ "NodeIterableType": NodeIterableType,
+ "NodeType": NodeType,
+ "PREFIX": PREFIX,
+ "PageInfo": PageInfo,
+}
- Returns:
- The resolved `Connection`
- .. _Relay Pagination algorithm:
- https://relay.dev/graphql/connections.htm#sec-Pagination-algorithm
- """
- slice_metadata = SliceMetadata.from_arguments(
- info,
- before=before,
- after=after,
- first=first,
- last=last,
+def __getattr__(name: str) -> Any:
+ if name in _DEPRECATIONS:
+ warnings.warn(
+ f"{name} should be imported from strawberry.pagination.types",
+ DeprecationWarning,
+ stacklevel=2,
)
+ return _DEPRECATIONS[name]
- type_def = get_object_definition(cls)
- assert type_def
- field_def = type_def.get_field("edges")
- assert field_def
-
- field = field_def.resolve_type(type_definition=type_def)
- while isinstance(field, StrawberryContainer):
- field = field.of_type
-
- edge_class = cast(Edge[NodeType], field)
-
- if isinstance(nodes, (AsyncIterator, AsyncIterable)) and in_async_context():
-
- async def resolver() -> Self:
- try:
- iterator = cast(
- Union[AsyncIterator[NodeType], AsyncIterable[NodeType]],
- cast(Sequence, nodes)[
- slice_metadata.start : slice_metadata.overfetch
- ],
- )
- except TypeError:
- # TODO: Why mypy isn't narrowing this based on the if above?
- assert isinstance(nodes, (AsyncIterator, AsyncIterable))
- iterator = aislice(
- nodes,
- slice_metadata.start,
- slice_metadata.overfetch,
- )
-
- # The slice above might return an object that now is not async
- # iterable anymore (e.g. an already cached django queryset)
- if isinstance(iterator, (AsyncIterator, AsyncIterable)):
- edges: List[Edge] = [
- edge_class.resolve_edge(
- cls.resolve_node(v, info=info, **kwargs),
- cursor=slice_metadata.start + i,
- )
- async for i, v in aenumerate(iterator)
- ]
- else:
- edges: List[Edge] = [ # type: ignore[no-redef]
- edge_class.resolve_edge(
- cls.resolve_node(v, info=info, **kwargs),
- cursor=slice_metadata.start + i,
- )
- for i, v in enumerate(iterator)
- ]
-
- has_previous_page = slice_metadata.start > 0
- if (
- slice_metadata.expected is not None
- and len(edges) == slice_metadata.expected + 1
- ):
- # Remove the overfetched result
- edges = edges[:-1]
- has_next_page = True
- elif slice_metadata.end == sys.maxsize:
- # Last was asked without any after/before
- assert last is not None
- original_len = len(edges)
- edges = edges[-last:]
- has_next_page = False
- has_previous_page = len(edges) != original_len
- else:
- has_next_page = False
-
- return cls(
- edges=edges,
- page_info=PageInfo(
- start_cursor=edges[0].cursor if edges else None,
- end_cursor=edges[-1].cursor if edges else None,
- has_previous_page=has_previous_page,
- has_next_page=has_next_page,
- ),
- )
-
- return resolver()
-
- try:
- iterator = cast(
- Union[Iterator[NodeType], Iterable[NodeType]],
- cast(Sequence, nodes)[slice_metadata.start : slice_metadata.overfetch],
- )
- except TypeError:
- assert isinstance(nodes, (Iterable, Iterator))
- iterator = itertools.islice(
- nodes,
- slice_metadata.start,
- slice_metadata.overfetch,
- )
-
- if not should_resolve_list_connection_edges(info):
- return cls(
- edges=[],
- page_info=PageInfo(
- start_cursor=None,
- end_cursor=None,
- has_previous_page=False,
- has_next_page=False,
- ),
- )
-
- edges = [
- edge_class.resolve_edge(
- cls.resolve_node(v, info=info, **kwargs),
- cursor=slice_metadata.start + i,
- )
- for i, v in enumerate(iterator)
- ]
-
- has_previous_page = slice_metadata.start > 0
- if (
- slice_metadata.expected is not None
- and len(edges) == slice_metadata.expected + 1
- ):
- # Remove the overfetched result
- edges = edges[:-1]
- has_next_page = True
- elif slice_metadata.end == sys.maxsize:
- # Last was asked without any after/before
- assert last is not None
- original_len = len(edges)
- edges = edges[-last:]
- has_next_page = False
- has_previous_page = len(edges) != original_len
- else:
- has_next_page = False
-
- return cls(
- edges=edges,
- page_info=PageInfo(
- start_cursor=edges[0].cursor if edges else None,
- end_cursor=edges[-1].cursor if edges else None,
- has_previous_page=has_previous_page,
- has_next_page=has_next_page,
- ),
- )
+ raise AttributeError(f"module {__name__} has no attribute {name}")
__all__ = [
@@ -941,11 +635,4 @@ async def resolver() -> Self:
"NodeID",
"NodeIDAnnotationError",
"NodeIDPrivate",
- "NodeIterableType",
- "NodeType",
- "PREFIX",
- "Connection",
- "Edge",
- "PageInfo",
- "ListConnection",
]
diff --git a/strawberry/relay/utils.py b/strawberry/relay/utils.py
index c25bd537b0..1a7a749c9d 100644
--- a/strawberry/relay/utils.py
+++ b/strawberry/relay/utils.py
@@ -1,198 +1,28 @@
-from __future__ import annotations
-
-import base64
-import dataclasses
-import sys
-from typing import TYPE_CHECKING, Any, Tuple, Union
-from typing_extensions import Self, assert_never
-
-from strawberry.types.base import StrawberryObjectDefinition
-from strawberry.types.nodes import InlineFragment, Selection
-
-if TYPE_CHECKING:
- from strawberry.types.info import Info
-
-
-def from_base64(value: str) -> Tuple[str, str]:
- """Parse the base64 encoded relay value.
-
- Args:
- value:
- The value to be parsed
-
- Returns:
- A tuple of (TypeName, NodeID).
-
- Raises:
- ValueError:
- If the value is not in the expected format
-
- """
- try:
- res = base64.b64decode(value.encode()).decode().split(":", 1)
- except Exception as e:
- raise ValueError(str(e)) from e
-
- if len(res) != 2:
- raise ValueError(f"{res} expected to contain only 2 items")
-
- return res[0], res[1]
-
-
-def to_base64(type_: Union[str, type, StrawberryObjectDefinition], node_id: Any) -> str:
- """Encode the type name and node id to a base64 string.
-
- Args:
- type_:
- The GraphQL type, type definition or type name.
- node_id:
- The node id itself
-
- Returns:
- A GlobalID, which is a string resulting from base64 encoding :.
-
- Raises:
- ValueError:
- If the value is not a valid GraphQL type or name
-
- """
- try:
- if isinstance(type_, str):
- type_name = type_
- elif isinstance(type_, StrawberryObjectDefinition):
- type_name = type_.name
- elif isinstance(type_, type):
- type_name = type_.__strawberry_definition__.name # type:ignore
- else: # pragma: no cover
- assert_never(type_)
- except Exception as e:
- raise ValueError(f"{type_} is not a valid GraphQL type or name") from e
-
- return base64.b64encode(f"{type_name}:{node_id}".encode()).decode()
-
-
-def should_resolve_list_connection_edges(info: Info) -> bool:
- """Check if the user requested to resolve the `edges` field of a connection.
-
- Args:
- info:
- The strawberry execution info resolve the type name from
-
- Returns:
- True if the user requested to resolve the `edges` field of a connection, False otherwise.
-
- """
- resolve_for_field_names = {"edges", "pageInfo"}
-
- def _check_selection(selection: Selection) -> bool:
- """Recursively inspect the selection to check if the user requested to resolve the `edges` field.
-
- Args:
- selection (Selection): The selection to check.
-
- Returns:
- bool: True if the user requested to resolve the `edges` field of a connection, False otherwise.
- """
- if (
- not isinstance(selection, InlineFragment)
- and selection.name in resolve_for_field_names
- ):
- return True
- if selection.selections:
- return any(
- _check_selection(selection) for selection in selection.selections
- )
- return False
-
- for selection_field in info.selected_fields:
- for selection in selection_field.selections:
- if _check_selection(selection):
- return True
- return False
-
-
-@dataclasses.dataclass
-class SliceMetadata:
- start: int
- end: int
- expected: int | None
-
- @property
- def overfetch(self) -> int:
- # Overfetch by 1 to check if we have a next result
- return self.end + 1 if self.end != sys.maxsize else self.end
-
- @classmethod
- def from_arguments(
- cls,
- info: Info,
- *,
- before: str | None = None,
- after: str | None = None,
- first: int | None = None,
- last: int | None = None,
- ) -> Self:
- """Get the slice metadata to use on ListConnection."""
- from strawberry.relay.types import PREFIX
-
- max_results = info.schema.config.relay_max_results
- start = 0
- end: int | None = None
-
- if after:
- after_type, after_parsed = from_base64(after)
- if after_type != PREFIX:
- raise TypeError("Argument 'after' contains a non-existing value.")
-
- start = int(after_parsed) + 1
- if before:
- before_type, before_parsed = from_base64(before)
- if before_type != PREFIX:
- raise TypeError("Argument 'before' contains a non-existing value.")
- end = int(before_parsed)
-
- if isinstance(first, int):
- if first < 0:
- raise ValueError("Argument 'first' must be a non-negative integer.")
-
- if first > max_results:
- raise ValueError(
- f"Argument 'first' cannot be higher than {max_results}."
- )
-
- if end is not None:
- start = max(0, end - 1)
-
- end = start + first
- if isinstance(last, int):
- if last < 0:
- raise ValueError("Argument 'last' must be a non-negative integer.")
-
- if last > max_results:
- raise ValueError(
- f"Argument 'last' cannot be higher than {max_results}."
- )
-
- if end is not None:
- start = max(start, end - last)
- else:
- end = sys.maxsize
-
- if end is None:
- end = start + max_results
-
- expected = end - start if end != sys.maxsize else None
-
- return cls(
- start=start,
- end=end,
- expected=expected,
+import warnings
+from typing import Any
+
+from strawberry.pagination.utils import (
+ SliceMetadata,
+ from_base64,
+ should_resolve_list_connection_edges,
+ to_base64,
+)
+
+_DEPRECATIONS = {
+ "SliceMetadata": SliceMetadata,
+ "from_base64": from_base64,
+ "should_resolve_list_connection_edges": should_resolve_list_connection_edges,
+ "to_base64": to_base64,
+}
+
+
+def __getattr__(name: str) -> Any:
+ if name in _DEPRECATIONS:
+ warnings.warn(
+ f"{name} should be imported from strawberry.pagination.utils",
+ DeprecationWarning,
+ stacklevel=2,
)
+ return _DEPRECATIONS[name]
-
-__all__ = [
- "from_base64",
- "to_base64",
- "should_resolve_list_connection_edges",
- "SliceMetadata",
-]
+ raise AttributeError(f"module {__name__} has no attribute {name}")
diff --git a/strawberry/utils.py b/strawberry/utils.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/pagination/__init__.py b/tests/pagination/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/pagination/test_exceptions.py b/tests/pagination/test_exceptions.py
new file mode 100644
index 0000000000..bdf735dfef
--- /dev/null
+++ b/tests/pagination/test_exceptions.py
@@ -0,0 +1,70 @@
+from typing import List
+
+import pytest
+
+import strawberry
+from strawberry.pagination.exceptions import (
+ ConnectionWrongAnnotationError,
+ ConnectionWrongResolverAnnotationError,
+)
+from strawberry.pagination.fields import connection
+from strawberry.pagination.types import Connection
+
+
+@pytest.mark.raises_strawberry_exception(
+ ConnectionWrongAnnotationError,
+ match=(
+ 'Wrong annotation used on field "fruits_conn". '
+ 'It should be annotated with a "Connection" subclass.'
+ ),
+)
+def test_raises_error_on_connection_missing_annotation():
+ @strawberry.type
+ class Fruit:
+ pk: str
+
+ @strawberry.type
+ class Query:
+ fruits_conn: List[Fruit] = connection()
+
+ strawberry.Schema(query=Query)
+
+
+@pytest.mark.raises_strawberry_exception(
+ ConnectionWrongAnnotationError,
+ match=(
+ 'Wrong annotation used on field "custom_resolver". '
+ 'It should be annotated with a "Connection" subclass.'
+ ),
+)
+def test_raises_error_on_connection_wrong_annotation():
+ @strawberry.type
+ class Fruit:
+ pk: str
+
+ @strawberry.type
+ class Query:
+ @connection(List[Fruit]) # type: ignore
+ def custom_resolver(self) -> List[Fruit]: ...
+
+ strawberry.Schema(query=Query)
+
+
+@pytest.mark.raises_strawberry_exception(
+ ConnectionWrongResolverAnnotationError,
+ match=(
+ 'Wrong annotation used on "custom_resolver" resolver. '
+ "It should be return an iterable or async iterable object."
+ ),
+)
+def test_raises_error_on_connection_resolver_wrong_annotation():
+ @strawberry.type
+ class Fruit:
+ pk: str
+
+ @strawberry.type
+ class Query:
+ @connection(Connection[Fruit]) # type: ignore
+ def custom_resolver(self): ...
+
+ strawberry.Schema(query=Query)
diff --git a/tests/pagination/test_types.py b/tests/pagination/test_types.py
new file mode 100644
index 0000000000..b3e9fd0e92
--- /dev/null
+++ b/tests/pagination/test_types.py
@@ -0,0 +1,145 @@
+from typing import AsyncGenerator, AsyncIterable, List
+
+from pytest_mock import MockerFixture
+
+import strawberry
+from strawberry.pagination.fields import connection
+from strawberry.pagination.types import ListConnection, NodeType
+
+
+async def test_resolve_async_list_connection():
+ @strawberry.type
+ class SomeType:
+ id: int
+
+ @strawberry.type
+ class Query:
+ @connection(ListConnection[SomeType])
+ async def some_type_conn(self) -> AsyncGenerator[SomeType, None]:
+ yield SomeType(id=0)
+ yield SomeType(id=1)
+ yield SomeType(id=2)
+
+ schema = strawberry.Schema(query=Query)
+ ret = await schema.execute(
+ """\
+ query {
+ someTypeConn {
+ edges {
+ node {
+ id
+ }
+ }
+ }
+ }
+ """
+ )
+ assert ret.errors is None
+ assert ret.data == {
+ "someTypeConn": {
+ "edges": [
+ {"node": {"id": 0}},
+ {"node": {"id": 1}},
+ {"node": {"id": 2}},
+ ],
+ }
+ }
+
+
+async def test_resolve_async_list_connection_but_sync_after_sliced():
+ # We are mimicking an object which is async iterable, but when sliced
+ # returns something that is not anymore. This is similar to an already
+ # prefetched django QuerySet, which is async iterable by default, but
+ # when sliced, since it is already prefetched, will return a list.
+ class Slicer:
+ def __init__(self, nodes) -> None:
+ self.nodes = nodes
+
+ async def __aiter__(self):
+ for n in self.nodes:
+ yield n
+
+ def __getitem__(self, key):
+ return self.nodes[key]
+
+ @strawberry.type
+ class SomeType:
+ id: int
+
+ @strawberry.type
+ class Query:
+ @connection(ListConnection[SomeType])
+ async def some_type_conn(self) -> AsyncIterable[SomeType]:
+ return Slicer([SomeType(id=0), SomeType(id=1), SomeType(id=2)])
+
+ schema = strawberry.Schema(query=Query)
+ ret = await schema.execute(
+ """\
+ query {
+ someTypeConn {
+ edges {
+ node {
+ id
+ }
+ }
+ }
+ }
+ """
+ )
+ assert ret.errors is None
+ assert ret.data == {
+ "someTypeConn": {
+ "edges": [
+ {"node": {"id": 0}},
+ {"node": {"id": 1}},
+ {"node": {"id": 2}},
+ ],
+ }
+ }
+
+
+def test_list_connection_without_edges_or_page_info_should_not_call_resolver(
+ mocker: MockerFixture,
+):
+ resolve_edge_mock = mocker.patch("strawberry.pagination.types.Edge.resolve_edge")
+
+ @strawberry.type(name="Connection", description="A connection to a list of items.")
+ class DummyListConnectionWithTotalCount(ListConnection[NodeType]):
+ @strawberry.field(description="Total quantity of existing nodes.")
+ def total_count(self) -> int:
+ return -1
+
+ @strawberry.type
+ class Fruit:
+ id: int
+
+ def fruits_resolver() -> List[Fruit]:
+ return [Fruit(id=1), Fruit(id=2), Fruit(id=3), Fruit(id=4), Fruit(id=5)]
+
+ fruits_resolver_spy = mocker.spy(fruits_resolver, "__call__")
+
+ @strawberry.type
+ class Query:
+ fruits: DummyListConnectionWithTotalCount[Fruit] = connection(
+ resolver=fruits_resolver
+ )
+
+ schema = strawberry.Schema(query=Query)
+ ret = schema.execute_sync(
+ """
+ query {
+ fruits {
+ totalCount
+ }
+ }
+ """
+ )
+ assert ret.errors is None
+ assert ret.data == {
+ "fruits": {
+ "totalCount": -1,
+ }
+ }
+
+ resolve_edge_mock.assert_not_called()
+ fruits_resolver_spy.assert_not_called()
diff --git a/tests/relay/test_utils.py b/tests/pagination/test_utils.py
similarity index 94%
rename from tests/relay/test_utils.py
rename to tests/pagination/test_utils.py
index e14a375522..13eee7402e 100644
--- a/tests/relay/test_utils.py
+++ b/tests/pagination/test_utils.py
@@ -6,8 +6,9 @@
import pytest
-from strawberry.relay.types import PREFIX
-from strawberry.relay.utils import (
+import strawberry
+from strawberry.pagination.types import PREFIX
+from strawberry.pagination.utils import (
SliceMetadata,
from_base64,
to_base64,
@@ -15,8 +16,6 @@
from strawberry.schema.config import StrawberryConfig
from strawberry.types.base import get_object_definition
-from .schema import Fruit
-
def test_from_base64():
type_name, node_id = from_base64("Zm9vYmFyOjE=") # foobar:1
@@ -55,11 +54,19 @@ def test_to_base64():
def test_to_base64_with_type():
+ @strawberry.type
+ class Fruit:
+ id: int
+
value = to_base64(Fruit, "1")
assert value == "RnJ1aXQ6MQ=="
def test_to_base64_with_typedef():
+ @strawberry.type
+ class Fruit:
+ id: int
+
value = to_base64(
get_object_definition(Fruit, strict=True),
"1",
diff --git a/tests/relay/test_exceptions.py b/tests/relay/test_exceptions.py
index 174c3748f7..ad0dfe0ae2 100644
--- a/tests/relay/test_exceptions.py
+++ b/tests/relay/test_exceptions.py
@@ -5,11 +5,7 @@
import strawberry
from strawberry import Info, relay
from strawberry.relay import GlobalID
-from strawberry.relay.exceptions import (
- NodeIDAnnotationError,
- RelayWrongAnnotationError,
- RelayWrongResolverAnnotationError,
-)
+from strawberry.relay.exceptions import NodeIDAnnotationError
@strawberry.type
@@ -96,62 +92,3 @@ class Query:
def fruits(self) -> List[Fruit]: ...
strawberry.Schema(query=Query)
-
-
-@pytest.mark.raises_strawberry_exception(
- RelayWrongAnnotationError,
- match=(
- 'Wrong annotation used on field "fruits_conn". '
- 'It should be annotated with a "Connection" subclass.'
- ),
-)
-def test_raises_error_on_connection_missing_annotation():
- @strawberry.type
- class Fruit(relay.Node):
- pk: relay.NodeID[str]
-
- @strawberry.type
- class Query:
- fruits_conn: List[Fruit] = relay.connection()
-
- strawberry.Schema(query=Query)
-
-
-@pytest.mark.raises_strawberry_exception(
- RelayWrongAnnotationError,
- match=(
- 'Wrong annotation used on field "custom_resolver". '
- 'It should be annotated with a "Connection" subclass.'
- ),
-)
-def test_raises_error_on_connection_wrong_annotation():
- @strawberry.type
- class Fruit(relay.Node):
- pk: relay.NodeID[str]
-
- @strawberry.type
- class Query:
- @relay.connection(List[Fruit]) # type: ignore
- def custom_resolver(self) -> List[Fruit]: ...
-
- strawberry.Schema(query=Query)
-
-
-@pytest.mark.raises_strawberry_exception(
- RelayWrongResolverAnnotationError,
- match=(
- 'Wrong annotation used on "custom_resolver" resolver. '
- "It should be return an iterable or async iterable object."
- ),
-)
-def test_raises_error_on_connection_resolver_wrong_annotation():
- @strawberry.type
- class Fruit(relay.Node):
- pk: relay.NodeID[str]
-
- @strawberry.type
- class Query:
- @relay.connection(relay.Connection[Fruit]) # type: ignore
- def custom_resolver(self): ...
-
- strawberry.Schema(query=Query)
diff --git a/tests/relay/test_types.py b/tests/relay/test_types.py
index 1171d7896a..cb9b4a4408 100644
--- a/tests/relay/test_types.py
+++ b/tests/relay/test_types.py
@@ -1,15 +1,13 @@
-from typing import Any, AsyncGenerator, AsyncIterable, Optional, Union, cast
+from typing import Any, Optional, Union, cast
from typing_extensions import assert_type
-from unittest.mock import MagicMock
import pytest
import strawberry
from strawberry import relay
-from strawberry.relay.utils import to_base64
from strawberry.types.info import Info
-from .schema import Fruit, FruitAsync, fruits_resolver, schema
+from .schema import Fruit, FruitAsync, schema
class FakeInfo:
@@ -149,97 +147,6 @@ class Foo: ...
fruit = await gid.resolve_node(fake_info, ensure_type=Foo)
-async def test_resolve_async_list_connection():
- @strawberry.type
- class SomeType(relay.Node):
- id: relay.NodeID[int]
-
- @strawberry.type
- class Query:
- @relay.connection(relay.ListConnection[SomeType])
- async def some_type_conn(self) -> AsyncGenerator[SomeType, None]:
- yield SomeType(id=0)
- yield SomeType(id=1)
- yield SomeType(id=2)
-
- schema = strawberry.Schema(query=Query)
- ret = await schema.execute(
- """\
- query {
- someTypeConn {
- edges {
- node {
- id
- }
- }
- }
- }
- """
- )
- assert ret.errors is None
- assert ret.data == {
- "someTypeConn": {
- "edges": [
- {"node": {"id": to_base64("SomeType", 0)}},
- {"node": {"id": to_base64("SomeType", 1)}},
- {"node": {"id": to_base64("SomeType", 2)}},
- ],
- }
- }
-
-
-async def test_resolve_async_list_connection_but_sync_after_sliced():
- # We are mimicking an object which is async iterable, but when sliced
- # returns something that is not anymore. This is similar to an already
- # prefetched django QuerySet, which is async iterable by default, but
- # when sliced, since it is already prefetched, will return a list.
- class Slicer:
- def __init__(self, nodes) -> None:
- self.nodes = nodes
-
- async def __aiter__(self):
- for n in self.nodes:
- yield n
-
- def __getitem__(self, key):
- return self.nodes[key]
-
- @strawberry.type
- class SomeType(relay.Node):
- id: relay.NodeID[int]
-
- @strawberry.type
- class Query:
- @relay.connection(relay.ListConnection[SomeType])
- async def some_type_conn(self) -> AsyncIterable[SomeType]:
- return Slicer([SomeType(id=0), SomeType(id=1), SomeType(id=2)])
-
- schema = strawberry.Schema(query=Query)
- ret = await schema.execute(
- """\
- query {
- someTypeConn {
- edges {
- node {
- id
- }
- }
- }
- }
- """
- )
- assert ret.errors is None
- assert ret.data == {
- "someTypeConn": {
- "edges": [
- {"node": {"id": to_base64("SomeType", 0)}},
- {"node": {"id": to_base64("SomeType", 1)}},
- {"node": {"id": to_base64("SomeType", 2)}},
- ],
- }
- }
-
-
def test_overwrite_resolve_id_and_no_node_id():
@strawberry.type
class Fruit(relay.Node):
@@ -258,39 +165,6 @@ def fruit(self) -> Fruit:
strawberry.Schema(query=Query)
-def test_list_connection_without_edges_or_page_info(mocker: MagicMock):
- @strawberry.type(name="Connection", description="A connection to a list of items.")
- class DummyListConnectionWithTotalCount(relay.ListConnection[relay.NodeType]):
- @strawberry.field(description="Total quantity of existing nodes.")
- def total_count(self) -> int:
- return -1
-
- @strawberry.type
- class Query:
- fruits: DummyListConnectionWithTotalCount[Fruit] = relay.connection(
- resolver=fruits_resolver
- )
-
- mock = mocker.patch("strawberry.relay.types.Edge.resolve_edge")
- schema = strawberry.Schema(query=Query)
- ret = schema.execute_sync(
- """
- query {
- fruits {
- totalCount
- }
- }
- """
- )
- mock.assert_not_called()
- assert ret.errors is None
- assert ret.data == {
- "fruits": {
- "totalCount": -1,
- }
- }
-
-
def test_list_connection_with_nested_fragments():
ret = schema.execute_sync(
"""