From e2442ebdc1aed7aa6dba2d6e6c2b2c94f28be9cf Mon Sep 17 00:00:00 2001 From: Simon Legtenborg Date: Tue, 14 May 2024 16:55:27 +0200 Subject: [PATCH] feat: rebuild aggregation serializer --- requirements-dev.txt | 1 + src/django_rest_aggregation/mixins.py | 8 ++- src/django_rest_aggregation/serializers.py | 60 ++++++++++++++++------ tests/test_aggregation.py | 16 +++--- 4 files changed, 60 insertions(+), 25 deletions(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 0f06113..8c66841 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,4 +1,5 @@ -r requirements.txt coverage==7.3.2 psycopg2==2.9.9 +mssql-django==1.5 ruff==0.3.7 \ No newline at end of file diff --git a/src/django_rest_aggregation/mixins.py b/src/django_rest_aggregation/mixins.py index 0081635..4ada77f 100644 --- a/src/django_rest_aggregation/mixins.py +++ b/src/django_rest_aggregation/mixins.py @@ -18,11 +18,15 @@ def aggregation(self, request): aggregator = Aggregator(request, queryset, self.get_aggregation_name()) filtered_queryset = self.filter_aggregated_queryset(aggregator.get_aggregated_queryset()) + context = { + "request": request, + "name": self.get_aggregation_name() + } page = self.paginate_queryset(filtered_queryset) if page is not None: - serializer = self.get_aggregation_serializer_class(page, many=True) + serializer = self.get_aggregation_serializer_class(page, many=True, context=context) return self.get_paginated_response(serializer.data) - serializer = self.get_aggregation_serializer_class(filtered_queryset, many=True) + serializer = self.get_aggregation_serializer_class(filtered_queryset, many=True, context=context) return Response(serializer.data) def get_aggregation_serializer_class(self, *args, **kwargs): diff --git a/src/django_rest_aggregation/serializers.py b/src/django_rest_aggregation/serializers.py index 4d8dfc5..a22f239 100644 --- a/src/django_rest_aggregation/serializers.py +++ b/src/django_rest_aggregation/serializers.py @@ -1,17 +1,45 @@ -import decimal - from rest_framework import serializers -from rest_framework.serializers import BaseSerializer - - -class AggregationSerializer(BaseSerializer): - def to_representation(self, instance): - ret = {} - for field in instance.keys(): - if isinstance(instance[field], int): - ret[field] = serializers.IntegerField(read_only=True).to_representation(instance[field]) - elif isinstance(instance[field], float) or isinstance(instance[field], decimal.Decimal): - ret[field] = serializers.FloatField(read_only=True).to_representation(float(instance[field])) - else: - ret[field] = serializers.CharField(read_only=True).to_representation(instance[field]) - return ret + + +class PassTroughField(serializers.Field): + + def to_representation(self, value): + return value + + def to_internal_value(self, data): + return data + + +class AggregationSerializer(serializers.Serializer): + """ + Serializer for aggregation requests. + Authored by Tim Streicher + """ + + value = serializers.SerializerMethodField(method_name='serialize_value') + + def serialize_value(self, obj): + return obj['value'] + + def __init__(self, *args, **kwargs): + # Instantiate the superclass normally + super(AggregationSerializer, self).__init__(*args, **kwargs) + + fields = self.context['request'].query_params.get('group_by') + allowed = {self.context['name']} + if fields: + fields = fields.split(',') + allowed.update(fields) + else: + allowed.add('group') + + # Drop any fields that are not specified in the `fields` argument except value. + existing = set(self.fields.keys()) + for field_name in existing - allowed: + self.fields.pop(field_name) + + for field_name in allowed: + field = self.fields.get(field_name) + # add field for unknown field_names + if field is None: + self.fields[field_name] = PassTroughField() diff --git a/tests/test_aggregation.py b/tests/test_aggregation.py index e179e90..45f4ddb 100644 --- a/tests/test_aggregation.py +++ b/tests/test_aggregation.py @@ -1,3 +1,5 @@ +import datetime + from django.db import connection from rest_framework.test import APITestCase @@ -152,7 +154,7 @@ def test_minimum(self): format="json", ) self.assertEqual(response.status_code, 200) - self.assertEqual(response.data, [{"group": "all", "value": "2020-01-01"}]) + self.assertEqual(response.data, [{"group": "all", "value": datetime.date(2020, 1, 1)}]) def test_maximum(self): # Maximum of IntegerField @@ -189,7 +191,7 @@ def test_maximum(self): format="json", ) self.assertEqual(response.status_code, 200) - self.assertEqual(response.data, [{"group": "all", "value": "2020-01-05"}]) + self.assertEqual(response.data, [{"group": "all", "value": datetime.date(2020, 1, 5)}]) class TestGroupingAndAnnotations(APITestCase): @@ -884,16 +886,16 @@ def test_ordering_all_field(self): self.assertEqual(response.status_code, 200) self.assertEqual(response.data["results"], []) - def test_sqlite_exception(self): - # throws exception if sqlite version is less than + def test_pagiation_count(self): response = self.client.get( - "/book/aggregation/", + "/customized_book/aggregation/", { "aggregation": "sum", "aggregation_field": "price", - "value__gte": 1, + "ordering": "CustomizedValue", + "CustomizedValue__lte": 0, }, format="json", ) self.assertEqual(response.status_code, 200) - self.assertEqual(response.data, [{"group": "all", "value": 10.55}]) + self.assertEqual(response.data["count"], 0) \ No newline at end of file