Skip to content

Commit 6eaf1ba

Browse files
committed
refactor(backend/mixins): make AggregateQuerysetMixin more readable
1 parent b07c8ea commit 6eaf1ba

File tree

1 file changed

+39
-41
lines changed

1 file changed

+39
-41
lines changed

backend/timed/mixins.py

+39-41
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,16 @@
33
from timed.serializers import AggregateObject
44

55

6+
def is_related_field(field: relations.Field) -> bool:
7+
"""Check whether value is a related field.
8+
9+
Ignores serializer method fields which define logic separately.
10+
"""
11+
return isinstance(field, relations.ResourceRelatedField) and not isinstance(
12+
field, relations.ManySerializerMethodResourceRelatedField
13+
)
14+
15+
616
class AggregateQuerysetMixin:
717
"""Add support for aggregate queryset in view.
818
@@ -18,59 +28,42 @@ class AggregateQuerysetMixin:
1828
Mixin expects the id to be the same key as the resource related
1929
field defined in the serializer.
2030
21-
To reduce number of queries `prefetch_related_for_field` can be defined
22-
to prefetch related data per field like the following:
2331
>>> from rest_framework.viewsets import ReadOnlyModelViewSet
2432
...
2533
...
2634
... class MyView(ReadOnlyModelViewSet, AggregateQuerysetMixin):
2735
... # ...
28-
... prefetch_related_for_field = {"field_name": ["field_name_prefetch"]}
29-
... # ...
3036
"""
3137

32-
def _is_related_field(self, val):
33-
"""Check whether value is a related field.
34-
35-
Ignores serializer method fields which define logic separately.
36-
"""
37-
return isinstance(val, relations.ResourceRelatedField) and not isinstance(
38-
val, relations.ManySerializerMethodResourceRelatedField
39-
)
40-
41-
def get_serializer(self, data=None, *args, **kwargs):
42-
# no data no wrapping needed
38+
def _get_data(self, data, *args, **kwargs):
4339
if not data:
44-
return super().get_serializer(data, *args, **kwargs)
40+
return data
4541

4642
many = kwargs.get("many")
47-
if not many:
48-
data = [data]
4943

5044
# prefetch data for all related fields
45+
5146
prefetch_per_field = {}
5247
serializer_class = self.get_serializer_class()
53-
for key, value in serializer_class._declared_fields.items(): # noqa: SLF001
54-
if self._is_related_field(value):
55-
source = value.source or key
56-
if many:
57-
obj_ids = data.values_list(source, flat=True)
58-
else:
59-
obj_ids = [data[0][source]]
60-
61-
qs = value.model.objects.filter(id__in=obj_ids)
62-
qs = qs.select_related()
63-
if hasattr(self, "prefetch_related_for_field"): # pragma: no cover
64-
qs = qs.prefetch_related(
65-
*self.prefetch_related_for_field.get(source, [])
66-
)
67-
68-
objects = {obj.id: obj for obj in qs}
69-
prefetch_per_field[source] = objects
48+
49+
lookup_expr = "id__in" if many else "id"
50+
51+
for key, value in filter(
52+
lambda kv: is_related_field(kv[1]),
53+
serializer_class._declared_fields.items(), # noqa: SLF001
54+
):
55+
source = value.source or key
56+
57+
lookup_value = data.values_list(source, flat=True) if many else data[source]
58+
59+
qs = value.model.objects.filter(**{lookup_expr: lookup_value})
60+
qs = qs.select_related()
61+
62+
prefetch_per_field[source] = {obj.id: obj for obj in qs}
7063

7164
# enhance entry dicts with model instances
72-
data = [
73-
AggregateObject(
65+
def _construct_aggregate_object(entry):
66+
return AggregateObject(
7467
**{
7568
**entry,
7669
**{
@@ -79,10 +72,15 @@ def get_serializer(self, data=None, *args, **kwargs):
7972
},
8073
}
8174
)
82-
for entry in data
83-
]
8475

8576
if not many:
86-
data = data[0]
77+
return _construct_aggregate_object(data)
8778

88-
return super().get_serializer(data, *args, **kwargs)
79+
return [_construct_aggregate_object(entry) for entry in data]
80+
81+
def get_serializer(self, data=None, *args, **kwargs):
82+
# no data no wrapping needed
83+
84+
return super().get_serializer(
85+
self._get_data(data, *args, **kwargs), *args, **kwargs
86+
)

0 commit comments

Comments
 (0)