Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(backend/mixins): make AggregateQuerysetMixin more readable #148

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 37 additions & 40 deletions backend/timed/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,16 @@
from timed.serializers import AggregateObject


def is_related_field(field: relations.Field) -> bool:
"""Check whether value is a related field.

Ignores serializer method fields which define logic separately.
"""
return isinstance(field, relations.ResourceRelatedField) and not isinstance(
field, relations.ManySerializerMethodResourceRelatedField
)


class AggregateQuerysetMixin:
"""Add support for aggregate queryset in view.

Expand All @@ -18,59 +28,43 @@ class AggregateQuerysetMixin:
Mixin expects the id to be the same key as the resource related
field defined in the serializer.

To reduce number of queries `prefetch_related_for_field` can be defined
to prefetch related data per field like the following:
>>> from rest_framework.viewsets import ReadOnlyModelViewSet
...
...
... class MyView(ReadOnlyModelViewSet, AggregateQuerysetMixin):
... # ...
... prefetch_related_for_field = {"field_name": ["field_name_prefetch"]}
... # ...
"""

def _is_related_field(self, val):
"""Check whether value is a related field.

Ignores serializer method fields which define logic separately.
"""
return isinstance(val, relations.ResourceRelatedField) and not isinstance(
val, relations.ManySerializerMethodResourceRelatedField
)

def get_serializer(self, data=None, *args, **kwargs):
def _get_data(self, data, *args, **kwargs):
# no data no wrapping needed
if not data:
return super().get_serializer(data, *args, **kwargs)
return data

many = kwargs.get("many")
if not many:
data = [data]

# prefetch data for all related fields

prefetch_per_field = {}
serializer_class = self.get_serializer_class()
for key, value in serializer_class._declared_fields.items(): # noqa: SLF001
if self._is_related_field(value):
source = value.source or key
if many:
obj_ids = data.values_list(source, flat=True)
else:
obj_ids = [data[0][source]]

qs = value.model.objects.filter(id__in=obj_ids)
qs = qs.select_related()
if hasattr(self, "prefetch_related_for_field"): # pragma: no cover
qs = qs.prefetch_related(
*self.prefetch_related_for_field.get(source, [])
)

objects = {obj.id: obj for obj in qs}
prefetch_per_field[source] = objects

lookup_expr = "id__in" if many else "id"

for key, value in filter(
lambda kv: is_related_field(kv[1]),
serializer_class._declared_fields.items(), # noqa: SLF001
):
source = value.source or key

lookup_value = data.values_list(source, flat=True) if many else data[source]

qs = value.model.objects.filter(**{lookup_expr: lookup_value})
qs = qs.select_related()

prefetch_per_field[source] = {obj.id: obj for obj in qs}

# enhance entry dicts with model instances
data = [
AggregateObject(
def _construct_aggregate_object(entry):
return AggregateObject(
**{
**entry,
**{
Expand All @@ -79,10 +73,13 @@ def get_serializer(self, data=None, *args, **kwargs):
},
}
)
for entry in data
]

if not many:
data = data[0]
return _construct_aggregate_object(data)

return [_construct_aggregate_object(entry) for entry in data]

return super().get_serializer(data, *args, **kwargs)
def get_serializer(self, data=None, *args, **kwargs):
return super().get_serializer(
self._get_data(data, *args, **kwargs), *args, **kwargs
)
Loading