Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pre-commit.ci] pre-commit autoupdate #660

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ repos:
- id: check-xml
- id: check-symlinks
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.3
rev: v0.8.1
hooks:
- id: ruff-format
- id: ruff
Expand Down
2 changes: 1 addition & 1 deletion strawberry_django/auth/mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def resolve_logout(info: Info) -> bool:

class DjangoRegisterMutation(DjangoCreateMutation):
def create(self, data: dict[str, Any], *, info: Info):
model = cast(type["AbstractBaseUser"], self.django_model)
model = cast("type[AbstractBaseUser]", self.django_model)
assert model is not None

password = data.pop("password")
Expand Down
4 changes: 2 additions & 2 deletions strawberry_django/extensions/django_cache_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ def __init__(
*,
execution_context: Optional[ExecutionContext] = None,
):
super().__init__(execution_context=cast(ExecutionContext, execution_context))
super().__init__(execution_context=cast("ExecutionContext", execution_context))

self.cache = caches[cache_name]
self.timeout = timeout or DEFAULT_TIMEOUT
# Use same key generating function as functools.lru_cache as default
self.hash_fn = hash_fn or (lambda args, kwargs: _make_key(args, kwargs, False))

def execute_cached(self, func, *args, **kwargs):
hash_key = cast(str, self.hash_fn(args, kwargs))
hash_key = cast("str", self.hash_fn(args, kwargs))
cache_result = self.cache.get(hash_key)
if cache_result is not None:
return cache_result
Expand Down
2 changes: 1 addition & 1 deletion strawberry_django/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def django_type(self) -> type[WithStrawberryDjangoObjectDefinition] | None:
object_definition.origin, (relay.Connection, OffsetPaginated)
):
origin_specialized_type_var_map = (
get_specialized_type_var_map(cast(type, origin)) or {}
get_specialized_type_var_map(cast("type", origin)) or {}
)
origin = origin_specialized_type_var_map.get("NodeType")

Expand Down
14 changes: 7 additions & 7 deletions strawberry_django/fields/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from strawberry.extensions.field_extension import FieldExtension
from strawberry.types.field import _RESOLVER_TYPE # noqa: PLC2701
from strawberry.types.fields.resolver import StrawberryResolver
from strawberry.types.info import Info # noqa: TCH002
from strawberry.types.info import Info # noqa: TC002
from strawberry.utils.await_maybe import await_maybe
from typing_extensions import TypeAlias

Expand Down Expand Up @@ -376,7 +376,7 @@ def default_resolver(
# If this is a nested field, call get_result instead because we want
# to retrieve the queryset from its RelatedManager
retval = cast(
models.QuerySet,
"models.QuerySet",
getattr(root, field.django_name or field.python_name).all(),
)
else:
Expand All @@ -394,7 +394,7 @@ def default_resolver(
required=True,
)

return cast(Iterable[Any], retval)
return cast("Iterable[Any]", retval)

default_resolver.can_optimize = True # type: ignore

Expand All @@ -415,7 +415,7 @@ def resolve(
**kwargs: Any,
) -> Any:
assert self.connection_type is not None
nodes = cast(Iterable[relay.Node], next_(source, info, **kwargs))
nodes = cast("Iterable[relay.Node]", next_(source, info, **kwargs))

# We have a single resolver for both sync and async, so we need to check if
# nodes is awaitable or not and resolve it accordingly
Expand Down Expand Up @@ -457,7 +457,7 @@ def apply(self, field: StrawberryField) -> None:
)

field.arguments = _get_field_arguments_for_extensions(field)
self.paginated_type = cast(type[OffsetPaginated], field.type)
self.paginated_type = cast("type[OffsetPaginated]", field.type)

def resolve(
self,
Expand All @@ -471,10 +471,10 @@ def resolve(
**kwargs: Any,
) -> Any:
assert self.paginated_type is not None
queryset = cast(models.QuerySet, next_(source, info, **kwargs))
queryset = cast("models.QuerySet", next_(source, info, **kwargs))

def get_queryset(queryset):
return cast(StrawberryDjangoField, info._field).get_queryset(
return cast("StrawberryDjangoField", info._field).get_queryset(
queryset,
info,
pagination=pagination,
Expand Down
8 changes: 4 additions & 4 deletions strawberry_django/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def process_filters(
assert has_object_definition(field_value)

queryset, sub_q = process_filters(
cast(WithStrawberryObjectDefinition, field_value),
cast("WithStrawberryObjectDefinition", field_value),
queryset,
info,
prefix,
Expand Down Expand Up @@ -231,7 +231,7 @@ def process_filters(
queryset = _process_deprecated_filter(filter_method, info, queryset)
elif has_object_definition(field_value):
queryset, sub_q = process_filters(
cast(WithStrawberryObjectDefinition, field_value),
cast("WithStrawberryObjectDefinition", field_value),
queryset,
info,
f"{prefix}{field_name}__",
Expand Down Expand Up @@ -264,7 +264,7 @@ def apply(
return queryset

queryset, q = process_filters(
cast(WithStrawberryObjectDefinition, filters), queryset, info
cast("WithStrawberryObjectDefinition", filters), queryset, info
)
if q:
queryset = queryset.filter(q)
Expand All @@ -289,7 +289,7 @@ def arguments(self) -> list[StrawberryArgument]:
arguments = []
if self.base_resolver is None:
filters = self.get_filters()
origin = cast(WithStrawberryObjectDefinition, self.origin)
origin = cast("WithStrawberryObjectDefinition", self.origin)
is_root_query = origin.__strawberry_definition__.name == "Query"

if (
Expand Down
6 changes: 3 additions & 3 deletions strawberry_django/integrations/guardian.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ def get_object_permission_models(
model: Union[models.Model, type[models.Model]],
) -> ObjectPermissionModels:
return ObjectPermissionModels(
user=cast(UserObjectPermissionBase, get_user_obj_perms_model(model)),
group=cast(GroupObjectPermissionBase, get_group_obj_perms_model(model)),
user=cast("UserObjectPermissionBase", get_user_obj_perms_model(model)),
group=cast("GroupObjectPermissionBase", get_group_obj_perms_model(model)),
)


def get_user_or_anonymous(user: UserType) -> UserType:
username = guardian_settings.ANONYMOUS_USER_NAME or ""
if user.is_anonymous and user.get_username() != username:
with contextlib.suppress(get_user_model().DoesNotExist):
return cast(UserType, _get_anonymous_user())
return cast("UserType", _get_anonymous_user())
return user
24 changes: 12 additions & 12 deletions strawberry_django/mutations/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def parse_input(
return data.resolve_node_sync(info, required=True)

if isinstance(data, NodeInput):
pk = cast(Any, parse_input(info, getattr(data, "id", UNSET), key_attr=key_attr))
pk = cast("Any", parse_input(info, getattr(data, "id", UNSET), key_attr=key_attr))
parsed = {}
for field in dataclasses.fields(data):
if field.name == "id":
Expand All @@ -172,7 +172,7 @@ def parse_input(

return ParsedObject(
pk=pk,
data=parsed if len(parsed) else None,
data=parsed or None,
)

if isinstance(data, (OneToOneInput, OneToManyInput)):
Expand All @@ -190,13 +190,13 @@ def parse_input(

return ParsedObjectList(
add=cast(
list[InputListTypes], parse_input(info, data.add, key_attr=key_attr)
"list[InputListTypes]", parse_input(info, data.add, key_attr=key_attr)
),
remove=cast(
list[InputListTypes], parse_input(info, data.remove, key_attr=key_attr)
"list[InputListTypes]", parse_input(info, data.remove, key_attr=key_attr)
),
set=cast(
list[InputListTypes], parse_input(info, data.set, key_attr=key_attr)
"list[InputListTypes]", parse_input(info, data.set, key_attr=key_attr)
),
)

Expand Down Expand Up @@ -263,7 +263,7 @@ def prepare_create_update(
):
value, value_data = _parse_data( # noqa: PLW2901
info,
cast(type[Model], field.related_model),
cast("type[Model]", field.related_model),
value,
key_attr=key_attr,
)
Expand Down Expand Up @@ -418,7 +418,7 @@ def update(
# Unwrap lazy objects since they have a proxy __iter__ method that will make
# them iterables even if the wrapped object isn't
if isinstance(instance, LazyObject):
instance = cast(_M, instance.__reduce__()[1][0])
instance = cast("_M", instance.__reduce__()[1][0])

if isinstance(instance, Iterable):
instances = list(instance)
Expand Down Expand Up @@ -482,7 +482,7 @@ def delete(info: Info, instance: _M | Iterable[_M], *, data=None) -> _M | list[_
# Unwrap lazy objects since they have a proxy __iter__ method that will make
# them iterables even if the wrapped object isn't
if isinstance(instance, LazyObject):
instance = cast(_M, instance.__reduce__()[1][0])
instance = cast("_M", instance.__reduce__()[1][0])

if isinstance(instance, Iterable):
many = True
Expand Down Expand Up @@ -511,7 +511,7 @@ def update_field(info: Info, instance: Model, field: models.Field, value: Any):
and isinstance(field, models.ForeignObject)
and not isinstance(value, Model)
):
value, data = _parse_pk(value, cast(type[Model], field.related_model))
value, data = _parse_pk(value, cast("type[Model]", field.related_model))

field.save_form_data(instance, value)
# If data was passed to the foreign key, update it recursively
Expand Down Expand Up @@ -578,7 +578,7 @@ def update_m2m(
need_remove_cache = need_remove_cache or bool(values)
for v in values:
obj, data = _parse_data(
info, cast(type[Model], manager.model), v, key_attr=key_attr
info, cast("type[Model]", manager.model), v, key_attr=key_attr
)
if obj:
data.pop(key_attr, None)
Expand Down Expand Up @@ -639,7 +639,7 @@ def update_m2m(
for v in value.add or []:
obj, data = _parse_data(
info,
cast(type[Model], manager.model),
cast("type[Model]", manager.model),
v,
key_attr=key_attr,
)
Expand All @@ -664,7 +664,7 @@ def update_m2m(
need_remove_cache = need_remove_cache or bool(value.remove)
for v in value.remove or []:
obj, data = _parse_data(
info, cast(type[Model], manager.model), v, key_attr=key_attr
info, cast("type[Model]", manager.model), v, key_attr=key_attr
)
data.pop(key_attr, None)
assert not data
Expand Down
12 changes: 6 additions & 6 deletions strawberry_django/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ def _optimize_prefetch_queryset(

mark_optimized = True

strawberry_schema = cast(Schema, info.schema._strawberry_schema) # type: ignore
strawberry_schema = cast("Schema", info.schema._strawberry_schema) # type: ignore
field_name = strawberry_schema.config.name_converter.from_field(field)
field_info = Info(
_raw_info=info,
Expand Down Expand Up @@ -597,7 +597,7 @@ def _get_selections(
info.schema,
info.fragments,
info.variable_values,
cast(GraphQLObjectType, parent_type),
cast("GraphQLObjectType", parent_type),
info.field_nodes,
)

Expand All @@ -613,7 +613,7 @@ def _generate_selection_resolve_info(
field_name=field_node.name.value,
field_nodes=field_nodes,
return_type=return_type,
parent_type=cast(GraphQLObjectType, parent_type),
parent_type=cast("GraphQLObjectType", parent_type),
path=info.path.add_key(0).add_key(field_node.name.value, parent_type.name),
schema=info.schema,
fragments=info.fragments,
Expand Down Expand Up @@ -1133,7 +1133,7 @@ def _get_model_hints_from_connection(

e_definition = get_object_definition(relay.Edge, strict=True)
e_type = e_definition.resolve_generic(
relay.Edge[cast(type[relay.Node], n_type)],
relay.Edge[cast("type[relay.Node]", n_type)],
)
e_gql_definition = _get_gql_definition(
schema,
Expand Down Expand Up @@ -1268,7 +1268,7 @@ def optimize(

"""
if isinstance(qs, BaseManager):
qs = cast(QuerySet[_M], qs.all())
qs = cast("QuerySet[_M]", qs.all())

if isinstance(qs, list):
# return sliced queryset as-is
Expand All @@ -1283,7 +1283,7 @@ def optimize(

config = config or OptimizerConfig()
store = store or OptimizerStore()
schema = cast(Schema, info.schema._strawberry_schema) # type: ignore
schema = cast("Schema", info.schema._strawberry_schema) # type: ignore

gql_type = get_named_type(info.return_type)
strawberry_type = schema.get_type_by_name(gql_type.name)
Expand Down
2 changes: 1 addition & 1 deletion strawberry_django/ordering.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def parse_and_fill(field: ObjectValueNode, seq: dict[str, OrderSequence]):
parse_and_fill(arg.value, sequence)

queryset, args = process_order(
cast(WithStrawberryObjectDefinition, order), info, queryset, sequence=sequence
cast("WithStrawberryObjectDefinition", order), info, queryset, sequence=sequence
)
if not args:
return queryset
Expand Down
2 changes: 1 addition & 1 deletion strawberry_django/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def results(self) -> list[NodeType]:
paginated_queryset = self.get_paginated_queryset()

return cast(
list[NodeType], paginated_queryset if paginated_queryset is not None else []
"list[NodeType]", paginated_queryset if paginated_queryset is not None else []
)

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion strawberry_django/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ def perm_checker(perm: PermDefinition) -> bool:
return (
user.has_perm(perm.perm) # type: ignore
if perm.permission
else user.has_module_perms(cast(str, perm.app)) # type: ignore
else user.has_module_perms(cast("str", perm.app)) # type: ignore
)

return perm_checker
Expand Down
Loading
Loading