Skip to content
10 changes: 8 additions & 2 deletions src/django_mysql/models/aggregates.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from django.db.models import Aggregate, CharField

from django_mysql.models.fields import ListCharField, SetCharField


class BitAnd(Aggregate):
function = "BIT_AND"
Expand All @@ -24,8 +26,12 @@ def __init__(
):

if "output_field" not in extra:
# This can/will be improved to SetTextField or ListTextField
extra["output_field"] = CharField()
if separator is not None:
extra["output_field"] = CharField()
elif distinct:
extra["output_field"] = SetCharField(CharField())
else:
extra["output_field"] = ListCharField(CharField())

super().__init__(expression, **extra)

Expand Down
15 changes: 9 additions & 6 deletions tests/testapp/test_aggregates.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,11 @@ def setUp(self):

def test_basic_aggregate_ids(self):
out = self.shakes.tutees.aggregate(tids=GroupConcat("id"))
concatted_ids = ",".join(self.str_tutee_ids)
assert out == {"tids": concatted_ids}
assert out == {"tids": self.str_tutee_ids}

def test_distinct_aggregate_ids(self):
out = self.shakes.tutees.aggregate(tids=GroupConcat("id", distinct=True))
assert out == {"tids": set(self.str_tutee_ids)}

def test_basic_annotate_ids(self):
concat = GroupConcat("tutees__id")
Expand Down Expand Up @@ -104,14 +107,14 @@ def test_separator_big(self):
def test_expression(self):
concat = GroupConcat(F("id") + 1)
out = self.shakes.tutees.aggregate(tids=concat)
concatted_ids = ",".join([str(self.jk.id + 1), str(self.grisham.id + 1)])
concatted_ids = [str(self.jk.id + 1), str(self.grisham.id + 1)]
assert out == {"tids": concatted_ids}

def test_application_order(self):
out = Author.objects.exclude(id=self.shakes.id).aggregate(
tids=GroupConcat("tutor_id", distinct=True)
)
assert out == {"tids": str(self.shakes.id)}
assert out == {"tids": {str(self.shakes.id)}}

@override_mysql_variables(SQL_MODE="ANSI")
def test_separator_ansi_mode(self):
Expand All @@ -127,11 +130,11 @@ def test_ordering_invalid(self):

def test_ordering_asc(self):
out = self.shakes.tutees.aggregate(tids=GroupConcat("id", ordering="asc"))
assert out == {"tids": ",".join(self.str_tutee_ids)}
assert out == {"tids": self.str_tutee_ids}

def test_ordering_desc(self):
out = self.shakes.tutees.aggregate(tids=GroupConcat("id", ordering="desc"))
assert out == {"tids": ",".join(reversed(self.str_tutee_ids))}
assert out == {"tids": list(reversed(self.str_tutee_ids))}

def test_separator_ordering(self):
concat = GroupConcat("id", separator=":", ordering="asc")
Expand Down