Skip to content

Commit

Permalink
refactor(backend/mixins): make AggregateQuerysetMixin more readable
Browse files Browse the repository at this point in the history
  • Loading branch information
c0rydoras committed Jun 17, 2024
1 parent fba36de commit 8c44ab4
Showing 1 changed file with 37 additions and 40 deletions.
77 changes: 37 additions & 40 deletions backend/timed/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,16 @@
from timed.serializers import AggregateObject


def is_related_field(field: relations.Field) -> bool:
"""Check whether value is a related field.
Ignores serializer method fields which define logic separately.
"""
return isinstance(field, relations.ResourceRelatedField) and not isinstance(
field, relations.ManySerializerMethodResourceRelatedField
)


class AggregateQuerysetMixin:
"""Add support for aggregate queryset in view.
Expand All @@ -18,59 +28,43 @@ class AggregateQuerysetMixin:
Mixin expects the id to be the same key as the resource related
field defined in the serializer.
To reduce number of queries `prefetch_related_for_field` can be defined
to prefetch related data per field like the following:
>>> from rest_framework.viewsets import ReadOnlyModelViewSet
...
...
... class MyView(ReadOnlyModelViewSet, AggregateQuerysetMixin):
... # ...
... prefetch_related_for_field = {"field_name": ["field_name_prefetch"]}
... # ...
"""

def _is_related_field(self, val):
"""Check whether value is a related field.
Ignores serializer method fields which define logic separately.
"""
return isinstance(val, relations.ResourceRelatedField) and not isinstance(
val, relations.ManySerializerMethodResourceRelatedField
)

def get_serializer(self, data=None, *args, **kwargs):
def _get_data(self, data, *args, **kwargs):
# no data no wrapping needed
if not data:
return super().get_serializer(data, *args, **kwargs)
return data

many = kwargs.get("many")
if not many:
data = [data]

# prefetch data for all related fields

prefetch_per_field = {}
serializer_class = self.get_serializer_class()
for key, value in serializer_class._declared_fields.items(): # noqa: SLF001
if self._is_related_field(value):
source = value.source or key
if many:
obj_ids = data.values_list(source, flat=True)
else:
obj_ids = [data[0][source]]

qs = value.model.objects.filter(id__in=obj_ids)
qs = qs.select_related()
if hasattr(self, "prefetch_related_for_field"): # pragma: no cover
qs = qs.prefetch_related(
*self.prefetch_related_for_field.get(source, [])
)

objects = {obj.id: obj for obj in qs}
prefetch_per_field[source] = objects

lookup_expr = "id__in" if many else "id"

for key, value in filter(
lambda kv: is_related_field(kv[1]),
serializer_class._declared_fields.items(), # noqa: SLF001
):
source = value.source or key

lookup_value = data.values_list(source, flat=True) if many else data[source]

qs = value.model.objects.filter(**{lookup_expr: lookup_value})
qs = qs.select_related()

prefetch_per_field[source] = {obj.id: obj for obj in qs}

# enhance entry dicts with model instances
data = [
AggregateObject(
def _construct_aggregate_object(entry):
return AggregateObject(
**{
**entry,
**{
Expand All @@ -79,10 +73,13 @@ def get_serializer(self, data=None, *args, **kwargs):
},
}
)
for entry in data
]

if not many:
data = data[0]
return _construct_aggregate_object(data)

return [_construct_aggregate_object(entry) for entry in data]

return super().get_serializer(data, *args, **kwargs)
def get_serializer(self, data=None, *args, **kwargs):
return super().get_serializer(
self._get_data(data, *args, **kwargs), *args, **kwargs
)

0 comments on commit 8c44ab4

Please sign in to comment.