3
3
from timed .serializers import AggregateObject
4
4
5
5
6
+ def is_related_field (field : relations .Field ) -> bool :
7
+ """Check whether value is a related field.
8
+
9
+ Ignores serializer method fields which define logic separately.
10
+ """
11
+ return isinstance (field , relations .ResourceRelatedField ) and not isinstance (
12
+ field , relations .ManySerializerMethodResourceRelatedField
13
+ )
14
+
15
+
6
16
class AggregateQuerysetMixin :
7
17
"""Add support for aggregate queryset in view.
8
18
@@ -18,59 +28,42 @@ class AggregateQuerysetMixin:
18
28
Mixin expects the id to be the same key as the resource related
19
29
field defined in the serializer.
20
30
21
- To reduce number of queries `prefetch_related_for_field` can be defined
22
- to prefetch related data per field like the following:
23
31
>>> from rest_framework.viewsets import ReadOnlyModelViewSet
24
32
...
25
33
...
26
34
... class MyView(ReadOnlyModelViewSet, AggregateQuerysetMixin):
27
35
... # ...
28
- ... prefetch_related_for_field = {"field_name": ["field_name_prefetch"]}
29
- ... # ...
30
36
"""
31
37
32
- def _is_related_field (self , val ):
33
- """Check whether value is a related field.
34
-
35
- Ignores serializer method fields which define logic separately.
36
- """
37
- return isinstance (val , relations .ResourceRelatedField ) and not isinstance (
38
- val , relations .ManySerializerMethodResourceRelatedField
39
- )
40
-
41
- def get_serializer (self , data = None , * args , ** kwargs ):
42
- # no data no wrapping needed
38
+ def _get_data (self , data , * args , ** kwargs ):
43
39
if not data :
44
- return super (). get_serializer ( data , * args , ** kwargs )
40
+ return data
45
41
46
42
many = kwargs .get ("many" )
47
- if not many :
48
- data = [data ]
49
43
50
44
# prefetch data for all related fields
45
+
51
46
prefetch_per_field = {}
52
47
serializer_class = self .get_serializer_class ()
53
- for key , value in serializer_class ._declared_fields .items (): # noqa: SLF001
54
- if self ._is_related_field (value ):
55
- source = value .source or key
56
- if many :
57
- obj_ids = data .values_list (source , flat = True )
58
- else :
59
- obj_ids = [data [0 ][source ]]
60
-
61
- qs = value .model .objects .filter (id__in = obj_ids )
62
- qs = qs .select_related ()
63
- if hasattr (self , "prefetch_related_for_field" ): # pragma: no cover
64
- qs = qs .prefetch_related (
65
- * self .prefetch_related_for_field .get (source , [])
66
- )
67
-
68
- objects = {obj .id : obj for obj in qs }
69
- prefetch_per_field [source ] = objects
48
+
49
+ lookup_expr = "id__in" if many else "id"
50
+
51
+ for key , value in filter (
52
+ lambda kv : is_related_field (kv [1 ]),
53
+ serializer_class ._declared_fields .items (), # noqa: SLF001
54
+ ):
55
+ source = value .source or key
56
+
57
+ lookup_value = data .values_list (source , flat = True ) if many else data [source ]
58
+
59
+ qs = value .model .objects .filter (** {lookup_expr : lookup_value })
60
+ qs = qs .select_related ()
61
+
62
+ prefetch_per_field [source ] = {obj .id : obj for obj in qs }
70
63
71
64
# enhance entry dicts with model instances
72
- data = [
73
- AggregateObject (
65
+ def _construct_aggregate_object ( entry ):
66
+ return AggregateObject (
74
67
** {
75
68
** entry ,
76
69
** {
@@ -79,10 +72,15 @@ def get_serializer(self, data=None, *args, **kwargs):
79
72
},
80
73
}
81
74
)
82
- for entry in data
83
- ]
84
75
85
76
if not many :
86
- data = data [ 0 ]
77
+ return _construct_aggregate_object ( data )
87
78
88
- return super ().get_serializer (data , * args , ** kwargs )
79
+ return [_construct_aggregate_object (entry ) for entry in data ]
80
+
81
+ def get_serializer (self , data = None , * args , ** kwargs ):
82
+ # no data no wrapping needed
83
+
84
+ return super ().get_serializer (
85
+ self ._get_data (data , * args , ** kwargs ), * args , ** kwargs
86
+ )
0 commit comments