diff --git a/CHANGELOG.md b/CHANGELOG.md index 21e536fa7e..0f92a7dd11 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -480,31 +480,34 @@ An example of migrating existing code is given below: # Existing code @strawberry.type class MyDataType: - name: str + name: str + @strawberry.type class Subscription: - @strawberry.subscription - async def my_data_subscription( - self, info: Info, groups: list[str] - ) -> AsyncGenerator[MyDataType | None, None]: - yield None - async for message in info.context["ws"].channel_listen("my_data", groups=groups): - yield MyDataType(name=message["payload"]) + @strawberry.subscription + async def my_data_subscription( + self, info: Info, groups: list[str] + ) -> AsyncGenerator[MyDataType | None, None]: + yield None + async for message in info.context["ws"].channel_listen( + "my_data", groups=groups + ): + yield MyDataType(name=message["payload"]) ``` ```py # New code @strawberry.type class Subscription: - @strawberry.subscription - async def my_data_subscription( - self, info: Info, groups: list[str] - ) -> AsyncGenerator[MyDataType | None, None]: - async with info.context["ws"].listen_to_channel("my_data", groups=groups) as cm: - yield None - async for message in cm: - yield MyDataType(name=message["payload"]) + @strawberry.subscription + async def my_data_subscription( + self, info: Info, groups: list[str] + ) -> AsyncGenerator[MyDataType | None, None]: + async with info.context["ws"].listen_to_channel("my_data", groups=groups) as cm: + yield None + async for message in cm: + yield MyDataType(name=message["payload"]) ``` Contributed by [Moritz Ulmer](https://github.com/moritz89) via [PR #2856](https://github.com/strawberry-graphql/strawberry/pull/2856/) @@ -1123,6 +1126,7 @@ class Point: x: float y: float + class GetPointsResult: circle_points: List[Point] square_points: List[Point] diff --git a/docs/_test.md b/docs/_test.md index 3aec4a1871..506d3759a6 100644 --- a/docs/_test.md +++ b/docs/_test.md @@ -16,6 +16,7 @@ Code blocks now support: ```python highlight=strawberry,str import strawberry + @strawberry.type class X: name: str @@ -26,6 +27,7 @@ class X: ```python lines=1-4 import strawberry + @strawberry.type class X: name: str diff --git a/docs/guides/pagination/connections.md b/docs/guides/pagination/connections.md index d6f526ca53..3ea01492fa 100644 --- a/docs/guides/pagination/connections.md +++ b/docs/guides/pagination/connections.md @@ -86,13 +86,12 @@ GenericType = TypeVar("GenericType") @strawberry.type class Connection(Generic[GenericType]): page_info: "PageInfo" = strawberry.field( - description="Information to aid in pagination." + description="Information to aid in pagination." ) edges: list["Edge[GenericType]"] = strawberry.field( - description="A list of edges in this connection." + description="A list of edges in this connection." ) - ``` Connections must have atleast two fields: `edges` and `page_info`. @@ -114,32 +113,31 @@ GenericType = TypeVar("GenericType") @strawberry.type class Connection(Generic[GenericType]): page_info: "PageInfo" = strawberry.field( - description="Information to aid in pagination." + description="Information to aid in pagination." ) edges: list["Edge[GenericType]"] = strawberry.field( - description="A list of edges in this connection." + description="A list of edges in this connection." ) @strawberry.type class PageInfo: has_next_page: bool = strawberry.field( - description="When paginating forwards, are there more items?" + description="When paginating forwards, are there more items?" ) has_previous_page: bool = strawberry.field( - description="When paginating backwards, are there more items?" + description="When paginating backwards, are there more items?" ) start_cursor: Optional[str] = strawberry.field( - description="When paginating backwards, the cursor to continue." + description="When paginating backwards, the cursor to continue." ) end_cursor: Optional[str] = strawberry.field( - description="When paginating forwards, the cursor to continue." + description="When paginating forwards, the cursor to continue." ) - ``` You can read more about the `PageInfo` type at: @@ -166,43 +164,38 @@ GenericType = TypeVar("GenericType") @strawberry.type class Connection(Generic[GenericType]): page_info: "PageInfo" = strawberry.field( - description="Information to aid in pagination." + description="Information to aid in pagination." ) edges: list["Edge[GenericType]"] = strawberry.field( - description="A list of edges in this connection." + description="A list of edges in this connection." ) @strawberry.type class PageInfo: has_next_page: bool = strawberry.field( - description="When paginating forwards, are there more items?" + description="When paginating forwards, are there more items?" ) has_previous_page: bool = strawberry.field( - description="When paginating backwards, are there more items?" + description="When paginating backwards, are there more items?" ) start_cursor: Optional[str] = strawberry.field( - description="When paginating backwards, the cursor to continue." + description="When paginating backwards, the cursor to continue." ) end_cursor: Optional[str] = strawberry.field( - description="When paginating forwards, the cursor to continue." + description="When paginating forwards, the cursor to continue." ) @strawberry.type class Edge(Generic[GenericType]): - node: GenericType = strawberry.field( - description="The item at the end of the edge." - ) - - cursor: str = strawberry.field( - description="A cursor for use in pagination." - ) + node: GenericType = strawberry.field(description="The item at the end of the edge.") + cursor: str = strawberry.field(description="A cursor for use in pagination.") ``` EdgeTypes must have atleast two fields - `cursor` and `node`. @@ -219,30 +212,30 @@ from typing import Generic, TypeVar import strawberry user_data = [ - { - "id": 1, - "name": "Norman Osborn", - "occupation": "Founder, Oscorp Industries", - "age": 42 - }, - { - "id": 2, - "name": "Peter Parker", - "occupation": "Freelance Photographer, The Daily Bugle", - "age": 20 - }, - { - "id": 3, - "name": "Harold Osborn", - "occupation": "President, Oscorp Industries", - "age": 19 - }, - { - "id": 4, - "name": "Eddie Brock", - "occupation": "Journalist, The Eddie Brock Report", - "age": 20 - } + { + "id": 1, + "name": "Norman Osborn", + "occupation": "Founder, Oscorp Industries", + "age": 42, + }, + { + "id": 2, + "name": "Peter Parker", + "occupation": "Freelance Photographer, The Daily Bugle", + "age": 20, + }, + { + "id": 3, + "name": "Harold Osborn", + "occupation": "President, Oscorp Industries", + "age": 19, + }, + { + "id": 4, + "name": "Eddie Brock", + "occupation": "Journalist, The Eddie Brock Report", + "age": 20, + }, ] @@ -252,43 +245,38 @@ GenericType = TypeVar("GenericType") @strawberry.type class Connection(Generic[GenericType]): page_info: "PageInfo" = strawberry.field( - description="Information to aid in pagination." + description="Information to aid in pagination." ) edges: list["Edge[GenericType]"] = strawberry.field( - description="A list of edges in this connection." + description="A list of edges in this connection." ) @strawberry.type class PageInfo: has_next_page: bool = strawberry.field( - description="When paginating forwards, are there more items?" + description="When paginating forwards, are there more items?" ) has_previous_page: bool = strawberry.field( - description="When paginating backwards, are there more items?" + description="When paginating backwards, are there more items?" ) start_cursor: Optional[str] = strawberry.field( - description="When paginating backwards, the cursor to continue." + description="When paginating backwards, the cursor to continue." ) end_cursor: Optional[str] = strawberry.field( - description="When paginating forwards, the cursor to continue." + description="When paginating forwards, the cursor to continue." ) @strawberry.type class Edge(Generic[GenericType]): - node: GenericType = strawberry.field( - description="The item at the end of the edge." - ) - - cursor: str = strawberry.field( - description="A cursor for use in pagination." - ) + node: GenericType = strawberry.field(description="The item at the end of the edge.") + cursor: str = strawberry.field(description="A cursor for use in pagination.") ``` Now is a good time to think of what we could use as a cursor for our dataset. Our cursor @@ -314,53 +302,54 @@ from typing import Generic, TypeVar import strawberry user_data = [ - { - "id": 1, - "name": "Norman Osborn", - "occupation": "Founder, Oscorp Industries", - "age": 42 - }, - { - "id": 2, - "name": "Peter Parker", - "occupation": "Freelance Photographer, The Daily Bugle", - "age": 20 - }, - { - "id": 3, - "name": "Harold Osborn", - "occupation": "President, Oscorp Industries", - "age": 19 - }, - { - "id": 4, - "name": "Eddie Brock", - "occupation": "Journalist, The Eddie Brock Report", - "age": 20 - } + { + "id": 1, + "name": "Norman Osborn", + "occupation": "Founder, Oscorp Industries", + "age": 42, + }, + { + "id": 2, + "name": "Peter Parker", + "occupation": "Freelance Photographer, The Daily Bugle", + "age": 20, + }, + { + "id": 3, + "name": "Harold Osborn", + "occupation": "President, Oscorp Industries", + "age": 19, + }, + { + "id": 4, + "name": "Eddie Brock", + "occupation": "Journalist, The Eddie Brock Report", + "age": 20, + }, ] + def encode_user_cursor(id: int) -> str: - """ - Encodes the given user ID into a cursor. + """ + Encodes the given user ID into a cursor. - :param id: The user ID to encode. + :param id: The user ID to encode. - :return: The encoded cursor. - """ - return b64encode(f"user:{id}".encode("ascii")).decode("ascii") + :return: The encoded cursor. + """ + return b64encode(f"user:{id}".encode("ascii")).decode("ascii") def decode_user_cursor(cursor: str) -> int: - """ - Decodes the user ID from the given cursor. + """ + Decodes the user ID from the given cursor. - :param cursor: The cursor to decode. + :param cursor: The cursor to decode. - :return: The decoded user ID. - """ - cursor_data = b64decode(cursor.encode("ascii")).decode("ascii") - return int(cursor_data.split(":")[1]) + :return: The decoded user ID. + """ + cursor_data = b64decode(cursor.encode("ascii")).decode("ascii") + return int(cursor_data.split(":")[1]) GenericType = TypeVar("GenericType") @@ -369,43 +358,38 @@ GenericType = TypeVar("GenericType") @strawberry.type class Connection(Generic[GenericType]): page_info: "PageInfo" = strawberry.field( - description="Information to aid in pagination." + description="Information to aid in pagination." ) edges: list["Edge[GenericType]"] = strawberry.field( - description="A list of edges in this connection." + description="A list of edges in this connection." ) @strawberry.type class PageInfo: has_next_page: bool = strawberry.field( - description="When paginating forwards, are there more items?" + description="When paginating forwards, are there more items?" ) has_previous_page: bool = strawberry.field( - description="When paginating backwards, are there more items?" + description="When paginating backwards, are there more items?" ) start_cursor: Optional[str] = strawberry.field( - description="When paginating backwards, the cursor to continue." + description="When paginating backwards, the cursor to continue." ) end_cursor: Optional[str] = strawberry.field( - description="When paginating forwards, the cursor to continue." + description="When paginating forwards, the cursor to continue." ) @strawberry.type class Edge(Generic[GenericType]): - node: GenericType = strawberry.field( - description="The item at the end of the edge." - ) - - cursor: str = strawberry.field( - description="A cursor for use in pagination." - ) + node: GenericType = strawberry.field(description="The item at the end of the edge.") + cursor: str = strawberry.field(description="A cursor for use in pagination.") ``` Let us define a `get_users` field which returns a connection of users, as well as an `UserType`. @@ -420,53 +404,54 @@ from typing import List, Optional, Generic, TypeVar import strawberry user_data = [ - { - "id": 1, - "name": "Norman Osborn", - "occupation": "Founder, Oscorp Industries", - "age": 42 - }, - { - "id": 2, - "name": "Peter Parker", - "occupation": "Freelance Photographer, The Daily Bugle", - "age": 20 - }, - { - "id": 3, - "name": "Harold Osborn", - "occupation": "President, Oscorp Industries", - "age": 19 - }, - { - "id": 4, - "name": "Eddie Brock", - "occupation": "Journalist, The Eddie Brock Report", - "age": 20 - } + { + "id": 1, + "name": "Norman Osborn", + "occupation": "Founder, Oscorp Industries", + "age": 42, + }, + { + "id": 2, + "name": "Peter Parker", + "occupation": "Freelance Photographer, The Daily Bugle", + "age": 20, + }, + { + "id": 3, + "name": "Harold Osborn", + "occupation": "President, Oscorp Industries", + "age": 19, + }, + { + "id": 4, + "name": "Eddie Brock", + "occupation": "Journalist, The Eddie Brock Report", + "age": 20, + }, ] + def encode_user_cursor(id: int) -> str: - """ - Encodes the given user ID into a cursor. + """ + Encodes the given user ID into a cursor. - :param id: The user ID to encode. + :param id: The user ID to encode. - :return: The encoded cursor. - """ - return b64encode(f"user:{id}".encode("ascii")).decode("ascii") + :return: The encoded cursor. + """ + return b64encode(f"user:{id}".encode("ascii")).decode("ascii") def decode_user_cursor(cursor: str) -> int: - """ - Decodes the user ID from the given cursor. + """ + Decodes the user ID from the given cursor. - :param cursor: The cursor to decode. + :param cursor: The cursor to decode. - :return: The decoded user ID. - """ - cursor_data = b64decode(cursor.encode("ascii")).decode("ascii") - return int(cursor_data.split(":")[1]) + :return: The decoded user ID. + """ + cursor_data = b64decode(cursor.encode("ascii")).decode("ascii") + return int(cursor_data.split(":")[1]) GenericType = TypeVar("GenericType") @@ -475,87 +460,80 @@ GenericType = TypeVar("GenericType") @strawberry.type class Connection(Generic[GenericType]): page_info: "PageInfo" = strawberry.field( - description="Information to aid in pagination." + description="Information to aid in pagination." ) edges: list["Edge[GenericType]"] = strawberry.field( - description="A list of edges in this connection." + description="A list of edges in this connection." ) @strawberry.type class PageInfo: has_next_page: bool = strawberry.field( - description="When paginating forwards, are there more items?" + description="When paginating forwards, are there more items?" ) has_previous_page: bool = strawberry.field( - description="When paginating backwards, are there more items?" + description="When paginating backwards, are there more items?" ) start_cursor: Optional[str] = strawberry.field( - description="When paginating backwards, the cursor to continue." + description="When paginating backwards, the cursor to continue." ) end_cursor: Optional[str] = strawberry.field( - description="When paginating forwards, the cursor to continue." + description="When paginating forwards, the cursor to continue." ) @strawberry.type class Edge(Generic[GenericType]): - node: GenericType = strawberry.field( - description="The item at the end of the edge." - ) + node: GenericType = strawberry.field(description="The item at the end of the edge.") + + cursor: str = strawberry.field(description="A cursor for use in pagination.") - cursor: str = strawberry.field( - description="A cursor for use in pagination." - ) @strawberry.type class User: - name: str = strawberry.field( - description="The name of the user." - ) + name: str = strawberry.field(description="The name of the user.") - occupation: str = strawberry.field( - description="The occupation of the user." - ) + occupation: str = strawberry.field(description="The occupation of the user.") - age: int = strawberry.field( - description="The age of the user." - ) + age: int = strawberry.field(description="The age of the user.") @strawberry.type class Query: @strawberry.field(description="Get a list of users.") - def get_users(self, first: int = 2, after: Optional[str] = None) -> Connection[User]: + def get_users( + self, first: int = 2, after: Optional[str] = None + ) -> Connection[User]: if after is not None: - # decode the user ID from the given cursor. - user_id = decode_user_cursor(cursor=after) + # decode the user ID from the given cursor. + user_id = decode_user_cursor(cursor=after) else: - # no cursor was given (this happens usually when the - # client sends a query for the first time). - user_id = 0 + # no cursor was given (this happens usually when the + # client sends a query for the first time). + user_id = 0 # filter the user data, going through the next set of results. filtered_data = map(lambda user: user.id > user_id, user_data) # slice the relevant user data (Here, we also slice an # additional user instance, to prepare the next cursor). - sliced_users = filtered_data[after:first+1] + sliced_users = filtered_data[after : first + 1] if len(sliced_users) > first: - # calculate the client's next cursor. - last_user = sliced_users.pop(-1) - next_cursor = encode_user_cursor(id=last_user.id) - has_next_page = True + # calculate the client's next cursor. + last_user = sliced_users.pop(-1) + next_cursor = encode_user_cursor(id=last_user.id) + has_next_page = True else: - # We have reached the last page, and - # don't have the next cursor. - next_cursor = None - has_next_page = False + # We have reached the last page, and + # don't have the next cursor. + next_cursor = None + has_next_page = False # We know that we have items in the # previous page window if the initial user ID @@ -564,28 +542,28 @@ class Query: # build user edges. edges = [ - Edge( - node=cast(UserType, user), - cursor=encode_user_cursor(id=user.id), - ) - for user in sliced_users + Edge( + node=cast(UserType, user), + cursor=encode_user_cursor(id=user.id), + ) + for user in sliced_users ] if edges: - # we have atleast one edge. Get the cursor - # of the first edge we have. - start_cursor = edges[0].cursor + # we have atleast one edge. Get the cursor + # of the first edge we have. + start_cursor = edges[0].cursor else: - # We have no edges to work with. - start_cursor = None + # We have no edges to work with. + start_cursor = None if len(edges) > 1: - # We have atleast 2 edges. Get the cursor - # of the last edge we have. - end_cursor = edges[-1].cursor + # We have atleast 2 edges. Get the cursor + # of the last edge we have. + end_cursor = edges[-1].cursor else: - # We don't have enough edges to work with. - end_cursor = None + # We don't have enough edges to work with. + end_cursor = None return Connection( edges=edges, @@ -594,11 +572,11 @@ class Query: has_previous_page=has_previous_page, start_cursor=start_cursor, end_cursor=end_cursor, - ) + ), ) -schema = strawberry.Schema(query=Query) +schema = strawberry.Schema(query=Query) ``` you can start the debug server with the following command: diff --git a/docs/guides/pagination/cursor-based.md b/docs/guides/pagination/cursor-based.md index 9a131ed39b..43ebf07e2c 100644 --- a/docs/guides/pagination/cursor-based.md +++ b/docs/guides/pagination/cursor-based.md @@ -65,31 +65,18 @@ import strawberry @strawberry.type class User: - id: str = strawberry.field( - description="ID of the user." - ) - - name: str = strawberry.field( - description="The name of the user." - ) + id: str = strawberry.field(description="ID of the user.") + name: str = strawberry.field(description="The name of the user.") - occupation: str = strawberry.field( - description="The occupation of the user." - ) - + occupation: str = strawberry.field(description="The occupation of the user.") - age: int = strawberry.field( - description="The age of the user." - ) + age: int = strawberry.field(description="The age of the user.") @staticmethod def from_row(row: Dict[str, Any]) -> "User": return User( - id=row['id'], - name=row['name'], - occupation=row['occupation'], - age=row['age'] + id=row["id"], name=row["name"], occupation=row["occupation"], age=row["age"] ) @@ -102,13 +89,9 @@ class PageMeta: @strawberry.type class UserResponse: - users: List[User] = strawberry.field( - description="The list of users." - ) + users: List[User] = strawberry.field(description="The list of users.") - page_meta: PageMeta = strawberry.field( - description="Metadata to aid in pagination." - ) + page_meta: PageMeta = strawberry.field(description="Metadata to aid in pagination.") @strawberry.type @@ -117,8 +100,8 @@ class Query: def get_users(self) -> UserResponse: ... -schema = strawberry.Schema(query=Query) +schema = strawberry.Schema(query=Query) ``` For simplicity's sake, our dataset is going to be an in-memory list. @@ -131,60 +114,47 @@ from typing import List, Optional, Dict, Any, cast import strawberry user_data = [ - { - "id": 1, - "name": "Norman Osborn", - "occupation": "Founder, Oscorp Industries", - "age": 42 - }, - { - "id": 2, - "name": "Peter Parker", - "occupation": "Freelance Photographer, The Daily Bugle", - "age": 20 - }, - { - "id": 3, - "name": "Harold Osborn", - "occupation": "President, Oscorp Industries", - "age": 19 - }, - { - "id": 4, - "name": "Eddie Brock", - "occupation": "Journalist, The Eddie Brock Report", - "age": 20 - } + { + "id": 1, + "name": "Norman Osborn", + "occupation": "Founder, Oscorp Industries", + "age": 42, + }, + { + "id": 2, + "name": "Peter Parker", + "occupation": "Freelance Photographer, The Daily Bugle", + "age": 20, + }, + { + "id": 3, + "name": "Harold Osborn", + "occupation": "President, Oscorp Industries", + "age": 19, + }, + { + "id": 4, + "name": "Eddie Brock", + "occupation": "Journalist, The Eddie Brock Report", + "age": 20, + }, ] @strawberry.type class User: - id: str = strawberry.field( - description="ID of the user." - ) - - name: str = strawberry.field( - description="The name of the user." - ) + id: str = strawberry.field(description="ID of the user.") + name: str = strawberry.field(description="The name of the user.") - occupation: str = strawberry.field( - description="The occupation of the user." - ) - + occupation: str = strawberry.field(description="The occupation of the user.") - age: int = strawberry.field( - description="The age of the user." - ) + age: int = strawberry.field(description="The age of the user.") @staticmethod def from_row(row: Dict[str, Any]) -> "User": return User( - id=row['id'], - name=row['name'], - occupation=row['occupation'], - age=row['age'] + id=row["id"], name=row["name"], occupation=row["occupation"], age=row["age"] ) @@ -197,13 +167,9 @@ class PageMeta: @strawberry.type class UserResponse: - users: List[User] = strawberry.field( - description="The list of users." - ) + users: List[User] = strawberry.field(description="The list of users.") - page_meta: PageMeta = strawberry.field( - description="Metadata to aid in pagination." - ) + page_meta: PageMeta = strawberry.field(description="Metadata to aid in pagination.") @strawberry.type @@ -212,8 +178,8 @@ class Query: def get_users(self) -> UserResponse: ... -schema = strawberry.Schema(query=Query) +schema = strawberry.Schema(query=Query) ``` Now is a good time to think of what we could use as a cursor for our dataset. Our cursor needs to be an opaque value, @@ -237,80 +203,70 @@ from typing import List, Optional, Dict, Any, cast import strawberry user_data = [ - { - "id": 1, - "name": "Norman Osborn", - "occupation": "Founder, Oscorp Industries", - "age": 42 - }, - { - "id": 2, - "name": "Peter Parker", - "occupation": "Freelance Photographer, The Daily Bugle", - "age": 20 - }, - { - "id": 3, - "name": "Harold Osborn", - "occupation": "President, Oscorp Industries", - "age": 19 - }, - { - "id": 4, - "name": "Eddie Brock", - "occupation": "Journalist, The Eddie Brock Report", - "age": 20 - } + { + "id": 1, + "name": "Norman Osborn", + "occupation": "Founder, Oscorp Industries", + "age": 42, + }, + { + "id": 2, + "name": "Peter Parker", + "occupation": "Freelance Photographer, The Daily Bugle", + "age": 20, + }, + { + "id": 3, + "name": "Harold Osborn", + "occupation": "President, Oscorp Industries", + "age": 19, + }, + { + "id": 4, + "name": "Eddie Brock", + "occupation": "Journalist, The Eddie Brock Report", + "age": 20, + }, ] + def encode_user_cursor(id: int) -> str: - """ - Encodes the given user ID into a cursor. + """ + Encodes the given user ID into a cursor. - :param id: The user ID to encode. + :param id: The user ID to encode. - :return: The encoded cursor. - """ - return b64encode(f"user:{id}".encode("ascii")).decode("ascii") + :return: The encoded cursor. + """ + return b64encode(f"user:{id}".encode("ascii")).decode("ascii") def decode_user_cursor(cursor: str) -> int: - """ - Decodes the user ID from the given cursor. + """ + Decodes the user ID from the given cursor. - :param cursor: The cursor to decode. + :param cursor: The cursor to decode. - :return: The decoded user ID. - """ - cursor_data = b64decode(cursor.encode("ascii")).decode("ascii") - return int(cursor_data.split(":")[1]) + :return: The decoded user ID. + """ + cursor_data = b64decode(cursor.encode("ascii")).decode("ascii") + return int(cursor_data.split(":")[1]) @strawberry.type class User: - id: str = strawberry.field( - description="ID of the user." - ) + id: str = strawberry.field(description="ID of the user.") - name: str = strawberry.field( - description="The name of the user." - ) + name: str = strawberry.field(description="The name of the user.") - occupation: str = strawberry.field( - description="The occupation of the user." - ) + occupation: str = strawberry.field(description="The occupation of the user.") - age: int = strawberry.field( - description="The age of the user." - ) + age: int = strawberry.field(description="The age of the user.") @staticmethod def from_row(row: Dict[str, Any]) -> "User": return User( - id=row['id'], - name=row['name'], - occupation=row['occupation'], - age=row['age'] + id=row["id"], name=row["name"], occupation=row["occupation"], age=row["age"] ) @@ -323,13 +279,9 @@ class PageMeta: @strawberry.type class UserResponse: - users: List[User] = strawberry.field( - description="The list of users." - ) + users: List[User] = strawberry.field(description="The list of users.") - page_meta: PageMeta = strawberry.field( - description="Metadata to aid in pagination." - ) + page_meta: PageMeta = strawberry.field(description="Metadata to aid in pagination.") @strawberry.type @@ -338,8 +290,8 @@ class Query: def get_users(self) -> UserResponse: ... -schema = strawberry.Schema(query=Query) +schema = strawberry.Schema(query=Query) ``` We're going to use the dataset we defined in our `get_users` field resolver. @@ -357,81 +309,70 @@ from typing import List, Optional, Dict, Any, cast import strawberry user_data = [ - { - "id": 1, - "name": "Norman Osborn", - "occupation": "Founder, Oscorp Industries", - "age": 42 - }, - { - "id": 2, - "name": "Peter Parker", - "occupation": "Freelance Photographer, The Daily Bugle", - "age": 20 - }, - { - "id": 3, - "name": "Harold Osborn", - "occupation": "President, Oscorp Industries", - "age": 19 - }, - { - "id": 4, - "name": "Eddie Brock", - "occupation": "Journalist, The Eddie Brock Report", - "age": 20 - } + { + "id": 1, + "name": "Norman Osborn", + "occupation": "Founder, Oscorp Industries", + "age": 42, + }, + { + "id": 2, + "name": "Peter Parker", + "occupation": "Freelance Photographer, The Daily Bugle", + "age": 20, + }, + { + "id": 3, + "name": "Harold Osborn", + "occupation": "President, Oscorp Industries", + "age": 19, + }, + { + "id": 4, + "name": "Eddie Brock", + "occupation": "Journalist, The Eddie Brock Report", + "age": 20, + }, ] + def encode_user_cursor(id: int) -> str: - """ - Encodes the given user ID into a cursor. + """ + Encodes the given user ID into a cursor. - :param id: The user ID to encode. + :param id: The user ID to encode. - :return: The encoded cursor. - """ - return b64encode(f"user:{id}".encode("ascii")).decode("ascii") + :return: The encoded cursor. + """ + return b64encode(f"user:{id}".encode("ascii")).decode("ascii") def decode_user_cursor(cursor: str) -> int: - """ - Decodes the user ID from the given cursor. + """ + Decodes the user ID from the given cursor. - :param cursor: The cursor to decode. + :param cursor: The cursor to decode. - :return: The decoded user ID. - """ - cursor_data = b64decode(cursor.encode("ascii")).decode("ascii") - return int(cursor_data.split(":")[1]) + :return: The decoded user ID. + """ + cursor_data = b64decode(cursor.encode("ascii")).decode("ascii") + return int(cursor_data.split(":")[1]) @strawberry.type class User: + id: str = strawberry.field(description="ID of the user.") - id: str = strawberry.field( - description="ID of the user." - ) + name: str = strawberry.field(description="The name of the user.") - name: str = strawberry.field( - description="The name of the user." - ) - - occupation: str = strawberry.field( - description="The occupation of the user." - ) + occupation: str = strawberry.field(description="The occupation of the user.") - age: int = strawberry.field( - description="The age of the user." - ) + age: int = strawberry.field(description="The age of the user.") @staticmethod def from_row(row: Dict[str, Any]) -> "User": return User( - id=row['id'], - name=row['name'], - occupation=row['occupation'], - age=row['age'] + id=row["id"], name=row["name"], occupation=row["occupation"], age=row["age"] ) @@ -444,13 +385,9 @@ class PageMeta: @strawberry.type class UserResponse: - users: List[User] = strawberry.field( - description="The list of users." - ) + users: List[User] = strawberry.field(description="The list of users.") - page_meta: PageMeta = strawberry.field( - description="Metadata to aid in pagination." - ) + page_meta: PageMeta = strawberry.field(description="Metadata to aid in pagination.") @strawberry.type @@ -458,40 +395,37 @@ class Query: @strawberry.field(description="Get a list of users.") def get_users(self, limit: int, cursor: Optional[str] = None) -> UserResponse: if cursor is not None: - # decode the user ID from the given cursor. - user_id = decode_user_cursor(cursor=cursor) + # decode the user ID from the given cursor. + user_id = decode_user_cursor(cursor=cursor) else: - # no cursor was given (this happens usually when the - # client sends a query for the first time). - user_id = 0 + # no cursor was given (this happens usually when the + # client sends a query for the first time). + user_id = 0 # filter the user data, going through the next set of results. - filtered_data = [user for user in user_data if user['id'] >= user_id] + filtered_data = [user for user in user_data if user["id"] >= user_id] # slice the relevant user data (Here, we also slice an # additional user instance, to prepare the next cursor). - sliced_users = filtered_data[:limit+1] + sliced_users = filtered_data[: limit + 1] if len(sliced_users) > limit: - # calculate the client's next cursor. - last_user = sliced_users.pop(-1) - next_cursor = encode_user_cursor(id=last_user['id']) + # calculate the client's next cursor. + last_user = sliced_users.pop(-1) + next_cursor = encode_user_cursor(id=last_user["id"]) else: - # We have reached the last page, and - # don't have the next cursor. - next_cursor = None + # We have reached the last page, and + # don't have the next cursor. + next_cursor = None sliced_users = [User.from_row(x) for x in sliced_users] return UserResponse( - users=sliced_users, - page_meta=PageMeta( - next_cursor=next_cursor - ) + users=sliced_users, page_meta=PageMeta(next_cursor=next_cursor) ) -schema = strawberry.Schema(query=Query) +schema = strawberry.Schema(query=Query) ``` diff --git a/docs/guides/pagination/offset-based.md b/docs/guides/pagination/offset-based.md index 015dfd8644..7985e12dd4 100644 --- a/docs/guides/pagination/offset-based.md +++ b/docs/guides/pagination/offset-based.md @@ -19,23 +19,13 @@ import strawberry @strawberry.type class User: - name: str = strawberry.field( - description="The name of the user." - ) - occupation: str = strawberry.field( - description="The occupation of the user." - ) - age: int = strawberry.field( - description="The age of the user." - ) + name: str = strawberry.field(description="The name of the user.") + occupation: str = strawberry.field(description="The occupation of the user.") + age: int = strawberry.field(description="The age of the user.") @staticmethod def from_row(row: Dict[str, Any]): - return User( - name=row['name'], - occupation=row['occupation'], - age=row['age'] - ) + return User(name=row["name"], occupation=row["occupation"], age=row["age"]) ``` Let us now model the `PaginationWindow`, which represents one "slice" of sorted, filtered, and paginated items. @@ -67,21 +57,21 @@ Let's define the query: @strawberry.type class Query: @strawberry.field(description="Get a list of users.") - def users(self, - order_by: str, - limit: int, - offset: int = 0, - name: str | None = None, - occupation: str| None = None - ) -> PaginationWindow[User]: - + def users( + self, + order_by: str, + limit: int, + offset: int = 0, + name: str | None = None, + occupation: str | None = None, + ) -> PaginationWindow[User]: filters = {} if name: - filters['name'] = name + filters["name"] = name if occupation: - filters['occupation'] = occupation + filters["occupation"] = occupation return get_pagination_window( dataset=user_data, @@ -89,9 +79,10 @@ class Query: order_by=order_by, limit=limit, offset=offset, - filters=filters + filters=filters, ) + schema = strawberry.Schema(query=Query) ``` @@ -101,30 +92,30 @@ For the sake of simplicity, our dataset will be an in-memory list containing fou ```py line=72-97 user_data = [ - { - "id": 1, - "name": "Norman Osborn", - "occupation": "Founder, Oscorp Industries", - "age": 42 - }, - { - "id": 2, - "name": "Peter Parker", - "occupation": "Freelance Photographer, The Daily Bugle", - "age": 20 - }, - { - "id": 3, - "name": "Harold Osborn", - "occupation": "President, Oscorp Industries", - "age": 19 - }, - { - "id": 4, - "name": "Eddie Brock", - "occupation": "Journalist, The Eddie Brock Report", - "age": 20 - } + { + "id": 1, + "name": "Norman Osborn", + "occupation": "Founder, Oscorp Industries", + "age": 42, + }, + { + "id": 2, + "name": "Peter Parker", + "occupation": "Freelance Photographer, The Daily Bugle", + "age": 20, + }, + { + "id": 3, + "name": "Harold Osborn", + "occupation": "President, Oscorp Industries", + "age": 19, + }, + { + "id": 4, + "name": "Eddie Brock", + "occupation": "Journalist, The Eddie Brock Report", + "age": 20, + }, ] ``` @@ -133,12 +124,13 @@ not only for the `User` type. ```py line=100-146 def get_pagination_window( - dataset: List[GenericType], - ItemType: type, - order_by: str, - limit: int, - offset: int = 0, - filters: dict[str, str] = {}) -> PaginationWindow: + dataset: List[GenericType], + ItemType: type, + order_by: str, + limit: int, + offset: int = 0, + filters: dict[str, str] = {}, +) -> PaginationWindow: """ Get one pagination window on the given dataset for the given limit and offset, ordered by the given attribute and filtered using the @@ -146,7 +138,7 @@ def get_pagination_window( """ if limit <= 0 or limit > 100: - raise Exception(f'limit ({limit}) must be between 0-100') + raise Exception(f"limit ({limit}) must be between 0-100") if filters: dataset = list(filter(lambda x: matches(x, filters), dataset)) @@ -154,19 +146,15 @@ def get_pagination_window( dataset.sort(key=lambda x: x[order_by]) if offset != 0 and not 0 <= offset < len(dataset): - raise Exception(f'offset ({offset}) is out of range ' - f'(0-{len(dataset) - 1})') + raise Exception(f"offset ({offset}) is out of range " f"(0-{len(dataset) - 1})") total_items_count = len(dataset) - items = dataset[offset:offset + limit] + items = dataset[offset : offset + limit] items = [ItemType.from_row(x) for x in items] - return PaginationWindow( - items=items, - total_items_count=total_items_count - ) + return PaginationWindow(items=items, total_items_count=total_items_count) def matches(item, filters): diff --git a/strawberry/channels/handlers/base.py b/strawberry/channels/handlers/base.py index 1c91927ce0..1dd09ea7c9 100644 --- a/strawberry/channels/handlers/base.py +++ b/strawberry/channels/handlers/base.py @@ -138,7 +138,7 @@ async def channel_listen( awaitable = asyncio.wait_for(awaitable, timeout) try: yield await awaitable - except asyncio.TimeoutError: # noqa: PERF203 + except asyncio.TimeoutError: # TODO: shall we add log here and maybe in the suppress below? return finally: @@ -215,7 +215,7 @@ async def _listen_to_channel_generator( awaitable = asyncio.wait_for(awaitable, timeout) try: yield await awaitable - except asyncio.TimeoutError: # noqa: PERF203 + except asyncio.TimeoutError: # TODO: shall we add log here and maybe in the suppress below? return diff --git a/strawberry/cli/commands/upgrade/_fake_progress.py b/strawberry/cli/commands/upgrade/_fake_progress.py index 40d92da5a0..6aacc36179 100644 --- a/strawberry/cli/commands/upgrade/_fake_progress.py +++ b/strawberry/cli/commands/upgrade/_fake_progress.py @@ -17,5 +17,5 @@ def add_task(self, *args: Any, **kwargs: Any) -> TaskID: def __enter__(self) -> "FakeProgress": return self - def __exit__(self, *args: Any, **kwargs: Any) -> None: + def __exit__(self, *args: object, **kwargs: Any) -> None: pass diff --git a/strawberry/federation/schema.py b/strawberry/federation/schema.py index c7b10e4e2a..a0a29ca4ad 100644 --- a/strawberry/federation/schema.py +++ b/strawberry/federation/schema.py @@ -209,7 +209,7 @@ def entities_resolver( try: result = get_result() - except Exception as e: # noqa: PERF203 + except Exception as e: result = GraphQLError( f"Unable to resolve reference for {definition.origin}", original_error=e, diff --git a/tests/fastapi/test_context.py b/tests/fastapi/test_context.py index 7b668e0096..ffd7ad64f9 100644 --- a/tests/fastapi/test_context.py +++ b/tests/fastapi/test_context.py @@ -115,7 +115,7 @@ class Query: @strawberry.field def abc(self, info: Info[Any, None]) -> str: assert info.context.get("request") is not None - assert "connection_params" not in info.context.keys() + assert "connection_params" not in info.context assert info.context.get("strawberry") == "rocks" return "abc" diff --git a/tests/http/conftest.py b/tests/http/conftest.py index 81d99115f3..487a1328ea 100644 --- a/tests/http/conftest.py +++ b/tests/http/conftest.py @@ -30,7 +30,7 @@ def _get_http_client_classes() -> Generator[Any, None, None]: importlib.import_module(f".{module}", package="tests.http.clients"), client, ) - except ImportError: # noqa: PERF203 + except ImportError: client_class = None yield pytest.param( diff --git a/tests/websockets/conftest.py b/tests/websockets/conftest.py index 4960c5bde7..00cfa6ac14 100644 --- a/tests/websockets/conftest.py +++ b/tests/websockets/conftest.py @@ -18,7 +18,7 @@ def _get_http_client_classes() -> Generator[Any, None, None]: client_class = getattr( importlib.import_module(f"tests.http.clients.{module}"), client ) - except ImportError: # noqa: PERF203 + except ImportError: client_class = None yield pytest.param(