diff --git a/src/django_rest_aggregation/aggregator.py b/src/django_rest_aggregation/aggregator.py index 2a458fc..91431c8 100644 --- a/src/django_rest_aggregation/aggregator.py +++ b/src/django_rest_aggregation/aggregator.py @@ -1,5 +1,6 @@ from django.core.exceptions import FieldDoesNotExist from django.db import models +from django.db.models import Case, When, Value from rest_framework.exceptions import ValidationError from rest_framework.request import Request @@ -83,7 +84,9 @@ def get_aggregated_queryset(self): self.validate_params() if (group_by := self.params.get("group_by", ["group"])) == ["group"]: - self.queryset = self.queryset.annotate(group=models.Value("all", output_field=models.CharField())) + self.queryset = self.queryset.annotate(group=Case( + When(pk__isnull=False, then=Value("all")), + default=Value("test"))) return self.queryset.values(*group_by).annotate(**get_annotation(self.params, self.aggregation_name)) diff --git a/src/django_rest_aggregation/mixins.py b/src/django_rest_aggregation/mixins.py index 4ada77f..6d4e157 100644 --- a/src/django_rest_aggregation/mixins.py +++ b/src/django_rest_aggregation/mixins.py @@ -10,7 +10,7 @@ class AggregationMixin: @action(methods=["get"], detail=False, url_path="aggregation", url_name="aggregation") def aggregation(self, request): - queryset = self.get_queryset() + queryset = self.get_queryset().order_by() if DjangoFilterBackend in self.filter_backends: queryset = DjangoFilterBackend().filter_queryset(request, queryset, self) @@ -43,12 +43,12 @@ def get_aggregation_name(self): def filter_aggregated_queryset(self, queryset): ordering_fields = getattr(self, "ordering_fields", []) - valid_fields = queryset[0].keys() + valid_fields = queryset[0].keys() if queryset.exists() else [] if ordering_fields == "__all__": ordering_fields = valid_fields else: - ordering_fields = list(set(ordering_fields).intersection(set(queryset[0].keys()))) + ordering_fields = list(set(ordering_fields).intersection(set(valid_fields))) if (fields := getattr(self, "aggregated_filterset_fields", None)) is not None: ValueFilter.set_filter_fields(fields, self.get_aggregation_name())