From a13c2d089f8db1bd5ff3b368dc61604ee57003e7 Mon Sep 17 00:00:00 2001 From: Roman Date: Wed, 10 Jan 2024 17:26:57 -0800 Subject: [PATCH 1/5] add support for GenericRelation field --- django_readers/qs.py | 35 +++++++++++++++++++++++++++++++++++ tests/models.py | 14 ++++++++++++++ tests/test_rest_framework.py | 28 +++++++++++++++++++++++++--- 3 files changed, 74 insertions(+), 3 deletions(-) diff --git a/django_readers/qs.py b/django_readers/qs.py index 59038f4..4dd5c73 100644 --- a/django_readers/qs.py +++ b/django_readers/qs.py @@ -1,3 +1,4 @@ +from django.contrib.contenttypes.fields import ReverseGenericManyToOneDescriptor from django.db.models import Prefetch, QuerySet from django.db.models.constants import LOOKUP_SEP from django.db.models.fields.related_descriptors import ( @@ -167,6 +168,32 @@ def prefetch_reverse_relationship( ) +def prefetch_reverse_generic_relationship( + name, related_field, related_queryset, prepare_related_queryset=noop, to_attr=None +): + """ + Efficiently prefetch a reverse generic relationship: one where the field on the "parent" + queryset is a `GenericRelation` field. We need to include this field in the query. + """ + return pipe( + include_fields(name), + prefetch_related( + Prefetch( + name, + pipe( + include_fields( + "pk", + related_field.content_type_field_name, + related_field.object_id_field_name, + ), + prepare_related_queryset, + )(related_queryset), + to_attr, + ) + ), + ) + + def prefetch_many_to_many_relationship( name, related_queryset, prepare_related_queryset=noop, to_attr=None ): @@ -246,5 +273,13 @@ def prepare(queryset): prepare_related_queryset, to_attr, )(queryset) + if type(related_descriptor) is ReverseGenericManyToOneDescriptor: + return prefetch_reverse_generic_relationship( + name, + related_descriptor.rel.field, + related_descriptor.field.related_model.objects.all(), + prepare_related_queryset, + to_attr, + )(queryset) return prepare diff --git a/tests/models.py b/tests/models.py index 72b0494..12e9fbd 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,6 +1,17 @@ +from django.contrib.contenttypes.fields import GenericRelation from django.db import models +class LogEntry(models.Model): + content_type = models.ForeignKey( + to="contenttypes.ContentType", + on_delete=models.CASCADE, + related_name="+", + ) + object_pk = models.CharField(max_length=255) + event = models.CharField(max_length=100) + + class Group(models.Model): name = models.CharField(max_length=100) @@ -15,6 +26,9 @@ class Widget(models.Model): value = models.PositiveIntegerField(default=0) other = models.CharField(max_length=100, null=True) owner = models.ForeignKey(Owner, null=True, on_delete=models.SET_NULL) + logs = GenericRelation( + LogEntry, content_type_field="content_type", object_id_field="object_pk" + ) class Thing(models.Model): diff --git a/tests/test_rest_framework.py b/tests/test_rest_framework.py index 872ebf7..d244910 100644 --- a/tests/test_rest_framework.py +++ b/tests/test_rest_framework.py @@ -1,3 +1,4 @@ +from django.contrib.contenttypes.models import ContentType from django.core.exceptions import ImproperlyConfigured from django.test import TestCase from django_readers import pairs, qs @@ -10,7 +11,7 @@ from rest_framework import serializers from rest_framework.generics import ListAPIView, RetrieveAPIView from rest_framework.test import APIRequestFactory -from tests.models import Category, Group, Owner, Widget +from tests.models import Category, Group, LogEntry, Owner, Widget from textwrap import dedent @@ -28,6 +29,7 @@ class WidgetListView(SpecMixin, ListAPIView): }, ] }, + {"logs": ["event"]}, ] @@ -53,17 +55,32 @@ class CategoryDetailView(SpecMixin, RetrieveAPIView): class RESTFrameworkTestCase(TestCase): def test_list(self): - Widget.objects.create( + widget = Widget.objects.create( name="test widget", owner=Owner.objects.create( name="test owner", group=Group.objects.create(name="test group") ), ) + LogEntry.objects.create( + content_type=ContentType.objects.get_for_model(widget), + object_pk=widget.id, + event="CREATED", + ) + LogEntry.objects.create( + content_type=ContentType.objects.get_for_model(widget), + object_pk=widget.id, + event="UPDATED", + ) + LogEntry.objects.create( + content_type=ContentType.objects.get_for_model(widget), + object_pk=widget.id, + event="DELETED", + ) request = APIRequestFactory().get("/") view = WidgetListView.as_view() - with self.assertNumQueries(3): + with self.assertNumQueries(4): response = view(request) self.assertEqual( @@ -77,6 +94,11 @@ def test_list(self): "name": "test group", }, }, + "logs": [ + {"event": "CREATED"}, + {"event": "UPDATED"}, + {"event": "DELETED"}, + ], } ], ) From 7c76d8e31b5b8fda8274be7e0d71bcc6a133fd9e Mon Sep 17 00:00:00 2001 From: Jamie Matthews Date: Fri, 12 Jan 2024 08:52:54 +0000 Subject: [PATCH 2/5] Change API for prefetch_reverse_generic_relationship slightly and add queryset-level tests --- django_readers/qs.py | 14 +++++-- tests/test_qs.py | 87 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 96 insertions(+), 5 deletions(-) diff --git a/django_readers/qs.py b/django_readers/qs.py index 4dd5c73..3cad769 100644 --- a/django_readers/qs.py +++ b/django_readers/qs.py @@ -169,7 +169,12 @@ def prefetch_reverse_relationship( def prefetch_reverse_generic_relationship( - name, related_field, related_queryset, prepare_related_queryset=noop, to_attr=None + name, + content_type_field_name, + object_id_field_name, + related_queryset, + prepare_related_queryset=noop, + to_attr=None, ): """ Efficiently prefetch a reverse generic relationship: one where the field on the "parent" @@ -183,8 +188,8 @@ def prefetch_reverse_generic_relationship( pipe( include_fields( "pk", - related_field.content_type_field_name, - related_field.object_id_field_name, + content_type_field_name, + object_id_field_name, ), prepare_related_queryset, )(related_queryset), @@ -276,7 +281,8 @@ def prepare(queryset): if type(related_descriptor) is ReverseGenericManyToOneDescriptor: return prefetch_reverse_generic_relationship( name, - related_descriptor.rel.field, + related_descriptor.rel.field.content_type_field_name, + related_descriptor.rel.field.object_id_field_name, related_descriptor.field.related_model.objects.all(), prepare_related_queryset, to_attr, diff --git a/tests/test_qs.py b/tests/test_qs.py index af32920..969b940 100644 --- a/tests/test_qs.py +++ b/tests/test_qs.py @@ -1,9 +1,10 @@ +from django.contrib.contenttypes.models import ContentType from django.db import connection from django.db.models import Count from django.test import TestCase from django.test.utils import CaptureQueriesContext from django_readers import qs -from tests.models import Category, Owner, Widget +from tests.models import Category, LogEntry, Owner, Widget from unittest import mock @@ -188,6 +189,84 @@ def test_prefetch_reverse_relationship(self): with self.assertNumQueries(0): self.assertEqual(owners[0].widget_set.all()[0].name, "test widget") + def test_prefetch_reverse_generic_relationship(self): + widget = Widget.objects.create(name="test widget") + LogEntry.objects.create( + content_type=ContentType.objects.get_for_model(widget), + object_pk=widget.id, + event="CREATED", + ) + + prepare = qs.pipe( + qs.include_fields("name"), + qs.prefetch_reverse_generic_relationship( + "logs", + "content_type", + "object_pk", + LogEntry.objects.all(), + qs.include_fields("event"), + ), + ) + + with CaptureQueriesContext(connection) as capture: + widgets = list(prepare(Widget.objects.all())) + + self.assertEqual(len(capture.captured_queries), 2) + + self.assertEqual( + capture.captured_queries[0]["sql"], + "SELECT " + '"tests_widget"."id", ' + '"tests_widget"."name" ' + "FROM " + '"tests_widget"', + ) + + content_type_id = ContentType.objects.get_for_model(Widget).pk + + self.assertEqual( + capture.captured_queries[1]["sql"], + "SELECT " + '"tests_logentry"."id", ' + '"tests_logentry"."content_type_id", ' + '"tests_logentry"."object_pk", ' + '"tests_logentry"."event" ' + "FROM " + '"tests_logentry" ' + "WHERE " + f'("tests_logentry"."content_type_id" = {content_type_id} AND ' + '"tests_logentry"."object_pk" IN ' + "('1'))", + ) + + with self.assertNumQueries(0): + self.assertEqual(widgets[0].logs.all()[0].event, "CREATED") + + def test_prefetch_reverse_generic_relationship_with_to_attr(self): + widget = Widget.objects.create(name="test widget") + LogEntry.objects.create( + content_type=ContentType.objects.get_for_model(widget), + object_pk=widget.id, + event="CREATED", + ) + + prepare = qs.pipe( + qs.include_fields("name"), + qs.prefetch_reverse_generic_relationship( + "logs", + "content_type", + "object_pk", + LogEntry.objects.all(), + qs.include_fields("event"), + to_attr="history", + ), + ) + + widgets = list(prepare(Widget.objects.all())) + + with self.assertNumQueries(0): + self.assertEqual(widgets[0].history[0].event, "CREATED") + def test_prefetch_reverse_relationship_only_loads_pk_and_related_name_by_default( self, ): @@ -358,6 +437,12 @@ def test_auto_prefetch_relationship(self): qs.auto_prefetch_relationship("category_set")(Widget.objects.all()) mock_fn.assert_called_once() + with mock.patch( + "django_readers.qs.prefetch_reverse_generic_relationship" + ) as mock_fn: + qs.auto_prefetch_relationship("logs")(Widget.objects.all()) + mock_fn.assert_called_once() + def test_annotate_only_includes_fk_by_default(self): owner = Owner.objects.create(name="test owner") Widget.objects.create(name="test 1", owner=owner) From 0269276b0e9a669745fbdcc6d631858d4a5d5ff2 Mon Sep 17 00:00:00 2001 From: Jamie Matthews Date: Fri, 12 Jan 2024 08:56:02 +0000 Subject: [PATCH 3/5] Add generic relation example to spec tests --- tests/test_specs.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/test_specs.py b/tests/test_specs.py index 31d6196..8805133 100644 --- a/tests/test_specs.py +++ b/tests/test_specs.py @@ -1,6 +1,7 @@ +from django.contrib.contenttypes.models import ContentType from django.test import TestCase from django_readers import specs -from tests.models import Category, Group, Owner, Thing, Widget +from tests.models import Category, Group, LogEntry, Owner, Thing, Widget class SpecTestCase(TestCase): @@ -28,6 +29,11 @@ def test_relationships(self): category = Category.objects.create(name="test category") category.widget_set.add(widget) Thing.objects.create(name="test thing", widget=widget) + LogEntry.objects.create( + content_type=ContentType.objects.get_for_model(widget), + object_pk=widget.id, + event="CREATED", + ) prepare, project = specs.process( [ @@ -35,13 +41,14 @@ def test_relationships(self): {"owner": ["name", {"widget_set": ["name"]}]}, {"category_set": ["name", {"widget_set": ["name"]}]}, {"thing": ["name", {"widget": ["name"]}]}, + {"logs": ["event"]}, ] ) with self.assertNumQueries(0): queryset = prepare(Widget.objects.all()) - with self.assertNumQueries(7): + with self.assertNumQueries(8): instance = queryset.first() with self.assertNumQueries(0): @@ -59,6 +66,7 @@ def test_relationships(self): {"name": "test category", "widget_set": [{"name": "test widget"}]}, ], "thing": {"name": "test thing", "widget": {"name": "test widget"}}, + "logs": [{"event": "CREATED"}], }, ) From ce2d16f06969822b0e9e7f322a3b5b5cc9192892 Mon Sep 17 00:00:00 2001 From: Jamie Matthews Date: Fri, 12 Jan 2024 09:32:37 +0000 Subject: [PATCH 4/5] Add support for GenericRelation fields to serializer generation --- django_readers/rest_framework.py | 20 ++++++++++++++++++-- tests/test_rest_framework.py | 10 ++++++++-- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/django_readers/rest_framework.py b/django_readers/rest_framework.py index 29b791c..de1eb22 100644 --- a/django_readers/rest_framework.py +++ b/django_readers/rest_framework.py @@ -1,4 +1,5 @@ from copy import deepcopy +from django.contrib.contenttypes.fields import ReverseGenericManyToOneDescriptor from django.core.exceptions import ImproperlyConfigured from django.utils.functional import cached_property from django_readers import specs @@ -124,10 +125,25 @@ def _get_child_serializer_kwargs(self, rel_info): kwargs["allow_null"] = True return kwargs + def _get_rel_info(self, rel_name): + descriptor = getattr(self.model, rel_name) + # Special case for reverse generic relations (GenericRelation field) + # as these don't appear in rest-framework's rel_info + if isinstance(descriptor, ReverseGenericManyToOneDescriptor): + return model_meta.RelationInfo( + model_field=descriptor.field, + related_model=descriptor.field.related_model, + to_many=True, + to_field=None, + has_through_model=False, + reverse=True, + ) + return self.info.relations[rel_name] + def visit_dict_item_list(self, key, value): # This is a relationship, so we recurse and create # a nested serializer to represent it - rel_info = self.info.relations[key] + rel_info = self._get_rel_info(key) capfirst = self._lowercase_with_underscores_to_capitalized_words(key) child_serializer_class = serializer_class_for_spec( f"{self.name}{capfirst}", @@ -143,7 +159,7 @@ def visit_dict_item_dict(self, key, value): # do the same as the previous case, but handled # slightly differently to set the `source` correctly relationship_name, relationship_spec = next(iter(value.items())) - rel_info = self.info.relations[relationship_name] + rel_info = self._get_rel_info(relationship_name) capfirst = self._lowercase_with_underscores_to_capitalized_words(key) child_serializer_class = serializer_class_for_spec( f"{self.name}{capfirst}", diff --git a/tests/test_rest_framework.py b/tests/test_rest_framework.py index d244910..546f43d 100644 --- a/tests/test_rest_framework.py +++ b/tests/test_rest_framework.py @@ -202,12 +202,16 @@ def test_all_relationship_types(self): }, ] }, + { + "logs": [ + "event", + ] + }, ] }, ] cls = serializer_class_for_spec("Owner", Owner, spec) - expected = dedent( """\ OwnerSerializer(): @@ -221,7 +225,9 @@ def test_all_relationship_types(self): thing = OwnerWidgetSetThingSerializer(read_only=True): name = CharField(max_length=100, read_only=True) related_widget = OwnerWidgetSetThingRelatedWidgetSerializer(allow_null=True, read_only=True, source='widget'): - name = CharField(allow_null=True, max_length=100, read_only=True, required=False)""" + name = CharField(allow_null=True, max_length=100, read_only=True, required=False) + logs = OwnerWidgetSetLogsSerializer(allow_null=True, many=True, read_only=True): + event = CharField(max_length=100, read_only=True)""" ) self.assertEqual(repr(cls()), expected) From de0549719aafb921a887f96351fe7592e678a7e6 Mon Sep 17 00:00:00 2001 From: Jamie Matthews Date: Fri, 12 Jan 2024 09:36:01 +0000 Subject: [PATCH 5/5] Update CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e966cd4..f851fbb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added +- Added support for Django's reverse generic relations (`GenericRelation` model field) ([#93](https://github.com/dabapps/django-readers/pull/93)). + ## [2.1.2] - 2023-07-17 ### Fixed