Skip to content

Commit 1a3c9fc

Browse files
committed
refactor(backend/mixins): make AggregateQuerysetMixin more readable
1 parent e26f4d3 commit 1a3c9fc

File tree

1 file changed

+37
-40
lines changed

1 file changed

+37
-40
lines changed

backend/timed/mixins.py

+37-40
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,16 @@
33
from timed.serializers import AggregateObject
44

55

6+
def is_related_field(field: relations.Field) -> bool:
7+
"""Check whether value is a related field.
8+
9+
Ignores serializer method fields which define logic separately.
10+
"""
11+
return isinstance(field, relations.ResourceRelatedField) and not isinstance(
12+
field, relations.ManySerializerMethodResourceRelatedField
13+
)
14+
15+
616
class AggregateQuerysetMixin:
717
"""Add support for aggregate queryset in view.
818
@@ -18,59 +28,43 @@ class AggregateQuerysetMixin:
1828
Mixin expects the id to be the same key as the resource related
1929
field defined in the serializer.
2030
21-
To reduce number of queries `prefetch_related_for_field` can be defined
22-
to prefetch related data per field like the following:
2331
>>> from rest_framework.viewsets import ReadOnlyModelViewSet
2432
...
2533
...
2634
... class MyView(ReadOnlyModelViewSet, AggregateQuerysetMixin):
2735
... # ...
28-
... prefetch_related_for_field = {"field_name": ["field_name_prefetch"]}
29-
... # ...
3036
"""
3137

32-
def _is_related_field(self, val):
33-
"""Check whether value is a related field.
34-
35-
Ignores serializer method fields which define logic separately.
36-
"""
37-
return isinstance(val, relations.ResourceRelatedField) and not isinstance(
38-
val, relations.ManySerializerMethodResourceRelatedField
39-
)
40-
41-
def get_serializer(self, data=None, *args, **kwargs):
38+
def _get_data(self, data, *args, **kwargs):
4239
# no data no wrapping needed
4340
if not data:
44-
return super().get_serializer(data, *args, **kwargs)
41+
return data
4542

4643
many = kwargs.get("many")
47-
if not many:
48-
data = [data]
4944

5045
# prefetch data for all related fields
46+
5147
prefetch_per_field = {}
5248
serializer_class = self.get_serializer_class()
53-
for key, value in serializer_class._declared_fields.items(): # noqa: SLF001
54-
if self._is_related_field(value):
55-
source = value.source or key
56-
if many:
57-
obj_ids = data.values_list(source, flat=True)
58-
else:
59-
obj_ids = [data[0][source]]
60-
61-
qs = value.model.objects.filter(id__in=obj_ids)
62-
qs = qs.select_related()
63-
if hasattr(self, "prefetch_related_for_field"): # pragma: no cover
64-
qs = qs.prefetch_related(
65-
*self.prefetch_related_for_field.get(source, [])
66-
)
67-
68-
objects = {obj.id: obj for obj in qs}
69-
prefetch_per_field[source] = objects
49+
50+
lookup_expr = "id__in" if many else "id"
51+
52+
for key, value in filter(
53+
lambda kv: is_related_field(kv[1]),
54+
serializer_class._declared_fields.items(), # noqa: SLF001
55+
):
56+
source = value.source or key
57+
58+
lookup_value = data.values_list(source, flat=True) if many else data[source]
59+
60+
qs = value.model.objects.filter(**{lookup_expr: lookup_value})
61+
qs = qs.select_related()
62+
63+
prefetch_per_field[source] = {obj.id: obj for obj in qs}
7064

7165
# enhance entry dicts with model instances
72-
data = [
73-
AggregateObject(
66+
def _construct_aggregate_object(entry):
67+
return AggregateObject(
7468
**{
7569
**entry,
7670
**{
@@ -79,10 +73,13 @@ def get_serializer(self, data=None, *args, **kwargs):
7973
},
8074
}
8175
)
82-
for entry in data
83-
]
8476

8577
if not many:
86-
data = data[0]
78+
return _construct_aggregate_object(data)
79+
80+
return [_construct_aggregate_object(entry) for entry in data]
8781

88-
return super().get_serializer(data, *args, **kwargs)
82+
def get_serializer(self, data=None, *args, **kwargs):
83+
return super().get_serializer(
84+
self._get_data(data, *args, **kwargs), *args, **kwargs
85+
)

0 commit comments

Comments
 (0)