Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deduce model field names from custom prefetches #473

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
99 changes: 75 additions & 24 deletions strawberry_django/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,17 @@ def with_hints(
),
)

def get_custom_prefetches(self, info: GraphQLResolveInfo) -> list[Prefetch]:
custom_prefetches = []
for p in self.prefetch_related:
if isinstance(p, Callable):
assert_type(p, PrefetchCallable)
p = p(info) # noqa: PLW2901
bellini666 marked this conversation as resolved.
Show resolved Hide resolved

if isinstance(p, Prefetch) and p.to_attr is not None:
diesieben07 marked this conversation as resolved.
Show resolved Hide resolved
custom_prefetches.append(p)
return custom_prefetches

def with_prefix(self, prefix: str, *, info: GraphQLResolveInfo):
prefetch_related = []
for p in self.prefetch_related:
Expand Down Expand Up @@ -458,7 +469,8 @@ def _get_model_hints(
continue

# Add annotations from the field if they exist
field_store = getattr(field, "store", None)
field_store = cast(OptimizerStore | None, getattr(field, "store", None))
diesieben07 marked this conversation as resolved.
Show resolved Hide resolved
custom_prefetches: list[Prefetch] = []
if field_store is not None:
if (
len(field_store.annotate) == 1
Expand All @@ -477,20 +489,45 @@ def _get_model_hints(
store |= (
field_store.with_prefix(prefix, info=info) if prefix else field_store
)
custom_prefetches = field_store.get_custom_prefetches(info)

# Then from the model property if one is defined
model_attr = getattr(model, field.python_name, None)
if model_attr is not None and isinstance(model_attr, ModelProperty):
attr_store = model_attr.store
store |= attr_store.with_prefix(prefix, info=info) if prefix else attr_store
attr_store_prefetches = attr_store.get_custom_prefetches(info)
if attr_store_prefetches:
custom_prefetches.extend(attr_store_prefetches)

model_fieldname: str | None = None
model_field = None
# try to find the model field name in any custom prefetches
if custom_prefetches:
for prefetch in custom_prefetches:
prefetch_field = model_fields.get(prefetch.prefetch_through, None)
if prefetch_field:
if not model_field:
model_field = prefetch_field
model_fieldname = prefetch.prefetch_through
elif model_field != prefetch_field:
# we found more than one model field from the custom prefetches
# not much we can do here
model_field = None
model_fieldname = None
custom_prefetches = []
break

# Lastly, from the django field itself
model_fieldname: str = getattr(field, "django_name", None) or field.python_name
model_field = model_fields.get(model_fieldname, None)
if not model_fieldname:
model_fieldname = getattr(field, "django_name", None) or field.python_name
model_field = model_fields.get(model_fieldname, None)

if model_field is not None:
path = f"{prefix}{model_fieldname}"

if isinstance(model_field, (models.ForeignKey, OneToOneRel)):
if not custom_prefetches and isinstance(model_field, (models.ForeignKey, OneToOneRel)):
# only select_related if there is no custom prefetch
store.only.append(path)
store.select_related.append(path)

Expand All @@ -517,7 +554,7 @@ def _get_model_hints(
if f_store is not None:
cache.setdefault(f_model, []).append((level, f_store))
store |= f_store.with_prefix(path, info=info)
elif GenericForeignKey and isinstance(model_field, GenericForeignKey):
elif not custom_prefetches and GenericForeignKey and isinstance(model_field, GenericForeignKey):
# There's not much we can do to optimize generic foreign keys regarding
# only/select_related because they can be anything.
# Just prefetch_related them
Expand All @@ -530,7 +567,8 @@ def _get_model_hints(
if len(f_types) > 1:
# This might be a generic foreign key.
# In this case, just prefetch it
store.prefetch_related.append(model_fieldname)
if not custom_prefetches:
store.prefetch_related.append(model_fieldname)
elif len(f_types) == 1:
remote_field = model_field.remote_field
remote_model = remote_field.model
Expand Down Expand Up @@ -590,24 +628,37 @@ def _get_model_hints(

cache.setdefault(remote_model, []).append((level, f_store))

# If prefetch_custom_queryset is false, use _base_manager here
# instead of _default_manager because we are getting related
# objects, and not querying it directly. Else use the type's
# get_queryset and model's custom QuerySet.
base_qs = _get_prefetch_queryset(
remote_model,
field,
config,
info,
)
f_qs = f_store.apply(
base_qs,
info=info,
config=config,
)
f_prefetch = Prefetch(path, queryset=f_qs)
f_prefetch._optimizer_sentinel = _sentinel # type: ignore
store.prefetch_related.append(f_prefetch)
if custom_prefetches:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (llm): This block introduces a significant change in how prefetches are handled based on the presence of custom_prefetches. It's a complex addition that could benefit from a bit more inline documentation to explain the rationale behind this approach, especially for future maintainers.

for prefetch in custom_prefetches:
if prefetch.queryset is not None:
p_qs = prefetch.queryset
else:
p_qs = _get_prefetch_queryset(remote_model, field, config, info)
f_qs = f_store.apply(p_qs, info=info, config=config)
f_prefetch = Prefetch(prefetch.prefetch_through, f_qs, prefetch.to_attr)
if prefix:
f_prefetch.add_prefix(prefix)
f_prefetch._optimizer_sentinel = _sentinel # type: ignore
store.prefetch_related.append(f_prefetch)
else:
# If prefetch_custom_queryset is false, use _base_manager here
# instead of _default_manager because we are getting related
# objects, and not querying it directly. Else use the type's
# get_queryset and model's custom QuerySet.
base_qs = _get_prefetch_queryset(
remote_model,
field,
config,
info,
)
f_qs = f_store.apply(
base_qs,
info=info,
config=config,
)
f_prefetch = Prefetch(path, queryset=f_qs)
f_prefetch._optimizer_sentinel = _sentinel # type: ignore
store.prefetch_related.append(f_prefetch)
else:
store.only.append(path)

Expand Down
21 changes: 19 additions & 2 deletions tests/projects/models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Optional, Annotated

import strawberry
from django.contrib.auth import get_user_model
from django.db import models
from django.db.models import Count, QuerySet
from django.db.models import Count, QuerySet, Prefetch
from django.utils.translation import gettext_lazy as _
from django_choices_field import TextChoicesField

Expand Down Expand Up @@ -54,6 +55,22 @@ class Status(models.TextChoices):
def is_small(self) -> bool:
return self._milestone_count < 3 # type: ignore

@model_property(
prefetch_related=lambda _: Prefetch(
"milestones",
to_attr="next_milestones_prop_pf",
queryset=Milestone.objects.filter(due_date__isnull=False).order_by("due_date")
)
)
def next_milestones_property(self) -> list[Annotated['MilestoneType', strawberry.lazy('.schema')]]:
"""
The milestones for the project ordered by their due date
"""
if hasattr(self, 'next_milestones_prop_pf'):
return self.next_milestones_prop_pf
else:
return self.milestones.filter(due_date__isnull=False).order_by("due_date")


class Milestone(models.Model):
issues: "RelatedManager[Issue]"
Expand Down
28 changes: 27 additions & 1 deletion tests/projects/schema.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import datetime
import decimal
from typing import Iterable, List, Optional, Type, cast
from typing import Iterable, List, Optional, Type, cast, Union

import strawberry
from django.contrib.auth import get_user_model
Expand Down Expand Up @@ -103,6 +103,24 @@ class ProjectType(relay.Node):
cost: strawberry.auto = strawberry_django.field(extensions=[IsAuthenticated()])
is_small: strawberry.auto

next_milestones_property: strawberry.auto

@strawberry_django.field(
prefetch_related=lambda _: Prefetch(
"milestones",
to_attr="next_milestones_pf",
queryset=Milestone.objects.filter(due_date__isnull=False).order_by("due_date")
)
)
def next_milestones(self) -> "list[MilestoneType]":
"""
The milestones for the project ordered by their due date
"""
if hasattr(self, 'next_milestones_pf'):
return self.next_milestones_pf
else:
return self.milestones.filter(due_date__isnull=False).order_by("due_date")


@strawberry_django.filter(Milestone, lookups=True)
class MilestoneFilter:
Expand Down Expand Up @@ -296,6 +314,14 @@ class ProjectConnection(ListConnectionWithTotalCount[ProjectType]):
"""Project connection documentation."""


ProjectFeedItem = Annotated[Union[IssueType, MilestoneType], strawberry.union('ProjectFeedItem')]


@strawberry.type
class ProjectFeedConnection(relay.Connection[ProjectFeedItem]):
pass


@strawberry.type
class Query:
"""All available queries for this schema."""
Expand Down
90 changes: 90 additions & 0 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,96 @@ def test_query_prefetch_with_fragments(db, gql_client: GraphQLTestClient):
assert res.data == {"project": e}


@pytest.mark.django_db(transaction=True)
def test_query_prefetch_with_to_attr(db, gql_client: GraphQLTestClient):
query = """
query TestQuery {
projectList {
id
nextMilestones {
id
name
project {
id
name
}
}
}
}
"""

expected = []
for p in ProjectFactory.create_batch(2):
p_res: dict[str, Any] = {
"id": to_base64("ProjectType", p.id),
"nextMilestones": [],
}
expected.append(p_res)
milestones = MilestoneFactory.create_batch(2, project=p)
milestones.sort(key=lambda ms: ms.due_date)
for m in milestones:
m_res: dict[str, Any] = {
"id": to_base64("MilestoneType", m.id),
"name": m.name,
"project": {
"id": p_res["id"],
"name": p.name,
},
}
p_res["nextMilestones"].append(m_res)

assert len(expected) == 2

with assert_num_queries(2 if DjangoOptimizerExtension.enabled.get() else 3):
res = gql_client.query(query)
assert res.data == {"projectList": expected}


@pytest.mark.django_db(transaction=True)
def test_query_prefetch_with_to_attr_model_property(db, gql_client: GraphQLTestClient):
query = """
query TestQuery {
projectList {
id
nextMilestonesProperty {
id
name
project {
id
name
}
}
}
}
"""

expected = []
for p in ProjectFactory.create_batch(2):
p_res: dict[str, Any] = {
"id": to_base64("ProjectType", p.id),
"nextMilestonesProperty": [],
}
expected.append(p_res)
milestones = MilestoneFactory.create_batch(2, project=p)
milestones.sort(key=lambda ms: ms.due_date)
for m in milestones:
m_res: dict[str, Any] = {
"id": to_base64("MilestoneType", m.id),
"name": m.name,
"project": {
"id": p_res["id"],
"name": p.name,
},
}
p_res["nextMilestonesProperty"].append(m_res)

assert len(expected) == 2

with assert_num_queries(2 if DjangoOptimizerExtension.enabled.get() else 3):
res = gql_client.query(query)
assert res.data == {"projectList": expected}


@pytest.mark.django_db(transaction=True)
def test_query_connection_with_resolver(db, gql_client: GraphQLTestClient):
query = """
Expand Down