Skip to content

Commit

Permalink
feat: rebuild aggregation serializer
Browse files Browse the repository at this point in the history
  • Loading branch information
Simon Legtenborg committed May 14, 2024
1 parent 4e0341a commit e2442eb
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 25 deletions.
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
-r requirements.txt
coverage==7.3.2
psycopg2==2.9.9
mssql-django==1.5
ruff==0.3.7
8 changes: 6 additions & 2 deletions src/django_rest_aggregation/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@ def aggregation(self, request):
aggregator = Aggregator(request, queryset, self.get_aggregation_name())
filtered_queryset = self.filter_aggregated_queryset(aggregator.get_aggregated_queryset())

context = {
"request": request,
"name": self.get_aggregation_name()
}
page = self.paginate_queryset(filtered_queryset)
if page is not None:
serializer = self.get_aggregation_serializer_class(page, many=True)
serializer = self.get_aggregation_serializer_class(page, many=True, context=context)
return self.get_paginated_response(serializer.data)
serializer = self.get_aggregation_serializer_class(filtered_queryset, many=True)
serializer = self.get_aggregation_serializer_class(filtered_queryset, many=True, context=context)
return Response(serializer.data)

def get_aggregation_serializer_class(self, *args, **kwargs):
Expand Down
60 changes: 44 additions & 16 deletions src/django_rest_aggregation/serializers.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,45 @@
import decimal

from rest_framework import serializers
from rest_framework.serializers import BaseSerializer


class AggregationSerializer(BaseSerializer):
def to_representation(self, instance):
ret = {}
for field in instance.keys():
if isinstance(instance[field], int):
ret[field] = serializers.IntegerField(read_only=True).to_representation(instance[field])
elif isinstance(instance[field], float) or isinstance(instance[field], decimal.Decimal):
ret[field] = serializers.FloatField(read_only=True).to_representation(float(instance[field]))
else:
ret[field] = serializers.CharField(read_only=True).to_representation(instance[field])
return ret


class PassTroughField(serializers.Field):

def to_representation(self, value):
return value

def to_internal_value(self, data):
return data


class AggregationSerializer(serializers.Serializer):
"""
Serializer for aggregation requests.
Authored by Tim Streicher
"""

value = serializers.SerializerMethodField(method_name='serialize_value')

def serialize_value(self, obj):
return obj['value']

def __init__(self, *args, **kwargs):
# Instantiate the superclass normally
super(AggregationSerializer, self).__init__(*args, **kwargs)

fields = self.context['request'].query_params.get('group_by')
allowed = {self.context['name']}
if fields:
fields = fields.split(',')
allowed.update(fields)
else:
allowed.add('group')

# Drop any fields that are not specified in the `fields` argument except value.
existing = set(self.fields.keys())
for field_name in existing - allowed:
self.fields.pop(field_name)

for field_name in allowed:
field = self.fields.get(field_name)
# add field for unknown field_names
if field is None:
self.fields[field_name] = PassTroughField()
16 changes: 9 additions & 7 deletions tests/test_aggregation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import datetime

from django.db import connection
from rest_framework.test import APITestCase

Expand Down Expand Up @@ -152,7 +154,7 @@ def test_minimum(self):
format="json",
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data, [{"group": "all", "value": "2020-01-01"}])
self.assertEqual(response.data, [{"group": "all", "value": datetime.date(2020, 1, 1)}])

def test_maximum(self):
# Maximum of IntegerField
Expand Down Expand Up @@ -189,7 +191,7 @@ def test_maximum(self):
format="json",
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data, [{"group": "all", "value": "2020-01-05"}])
self.assertEqual(response.data, [{"group": "all", "value": datetime.date(2020, 1, 5)}])


class TestGroupingAndAnnotations(APITestCase):
Expand Down Expand Up @@ -884,16 +886,16 @@ def test_ordering_all_field(self):
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data["results"], [])

def test_sqlite_exception(self):
# throws exception if sqlite version is less than
def test_pagiation_count(self):
response = self.client.get(
"/book/aggregation/",
"/customized_book/aggregation/",
{
"aggregation": "sum",
"aggregation_field": "price",
"value__gte": 1,
"ordering": "CustomizedValue",
"CustomizedValue__lte": 0,
},
format="json",
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data, [{"group": "all", "value": 10.55}])
self.assertEqual(response.data["count"], 0)

0 comments on commit e2442eb

Please sign in to comment.