Skip to content

Commit

Permalink
Merge pull request #7 from mahdihaghverdi/more-tests
Browse files Browse the repository at this point in the history
Add _more_ tests for `field`, `group`, `annotate` and `addField` methods and improve coverage by 7%
  • Loading branch information
seyed-dev authored Oct 28, 2023
2 parents 19dddcc + 39d7b66 commit 44a357b
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 8 deletions.
19 changes: 11 additions & 8 deletions aggify/aggify.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,7 @@ def unwind(path, preserve=True):
}

def to_aggregate(self):
"""
Builds the pipelines list based on the query parameters.
"""
"""Builds the pipelines list based on the query parameters."""
skip_list = []
for key, value in self.q.items():
if key in skip_list:
Expand Down Expand Up @@ -227,8 +225,11 @@ def group(self, key="_id"):
return self

def annotate(self, annotate_name, accumulator, f):
if (stage := list(self.pipelines[-1].keys())[0]) != "$group":
raise ValueError(f"Annotations apply only to $group, not to {stage}.")
try:
if (stage := list(self.pipelines[-1].keys())[0]) != "$group":
raise ValueError(f"Annotations apply only to $group, not to {stage}.")
except IndexError:
raise ValueError(f"Annotations apply only to $group, you're pipeline is empty.")

accumulator_dict = {
"sum": "$sum",
Expand All @@ -243,15 +244,17 @@ def annotate(self, annotate_name, accumulator, f):
"stdDevSamp": "$stdDevSamp" # noqa
}

acc = accumulator_dict.get(accumulator, None)
if not acc:
raise ValueError(f"Invalid accumulator: {accumulator}")
try:
acc = accumulator_dict[accumulator]
except KeyError:
raise ValueError(f"Invalid accumulator: {accumulator}") from None

if isinstance(f, F):
value = f.to_dict()
else:
value = f"${f}"
self.pipelines[-1]['$group'] |= {annotate_name: {acc: value}}
return self

def order_by(self, field):
self.pipelines.append({'$sort': {
Expand Down
104 changes: 104 additions & 0 deletions tests/test_aggify.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class BaseModel(Document):
'abstract': True
}


# This defines a base document model for MongoDB using MongoEngine, with 'name' and 'age' fields.
# The 'allow_inheritance' and 'abstract' options ensure it's used as a base class for other documents.

Expand Down Expand Up @@ -132,6 +133,16 @@ def test_filter_with_not_operator(self):
assert len(aggify.pipelines) == 1
assert aggify.pipelines[0]["$match"]["$not"][0]["name"] == "John"

def test_add_field_value_error(self):
with pytest.raises(ValueError) as err:
aggify = Aggify(BaseModel)
fields = {
"new_field_1": True,
}
aggify.addFields(fields)

assert 'invalid field expression' in err.__str__().lower()

def test_add_fields_string_literal(self):
aggify = Aggify(BaseModel)
fields = {
Expand Down Expand Up @@ -164,3 +175,96 @@ def test_add_fields_with_f_expression(self):
}
}
assert add_fields_stage.pipelines[0] == expected_stage

def test_filter_value_error(self):
with pytest.raises(ValueError) as err:
Aggify(BaseModel).filter(arg='Hi')

assert 'invalid' in err.__str__().lower()

def test_group(self):
aggify = Aggify(BaseModel)
thing = aggify.group('name')
assert len(thing.pipelines) == 1
assert thing.pipelines[-1]['$group'] == {'_id': '$name'}

def test_annotate_empty_pipeline_value_error(self):
with pytest.raises(ValueError) as err:
Aggify(BaseModel).annotate('size', 'sum', None)

assert "you're pipeline is empty" in err.__str__().lower()

def test_annotate_not_group_value_error(self):
with pytest.raises(ValueError) as err:
Aggify(BaseModel)[1].annotate('size', 'sum', None)

assert 'not to $limit' in err.__str__().lower()

def test_annotate_invalid_accumulator(self):
with pytest.raises(ValueError) as err:
Aggify(BaseModel).group('name').annotate('size', 'mahdi', None)

assert 'invalid accumulator' in err.__str__().lower()

@pytest.mark.parametrize(
'accumulator',
(
'sum',
"avg",
"first",
"last",
"max",
"min",
"push",
"addToSet",
"stdDevPop",
"stdDevSamp",
)
)
def test_annotate_with_raw_f(self, accumulator):
aggify = Aggify(BaseModel)
thing = aggify.group().annotate('price', accumulator, F('price'))
assert len(thing.pipelines) == 1
assert thing.pipelines[-1]['$group']['price'] == {f'${accumulator}': '$price'}

@pytest.mark.parametrize(
'accumulator',
(
'sum',
"avg",
"first",
"last",
"max",
"min",
"push",
"addToSet",
"stdDevPop",
"stdDevSamp",
)
)
def test_annotate_with_f(self, accumulator):
aggify = Aggify(BaseModel)
thing = aggify.group().annotate('price', accumulator, F('price') * 10)
assert len(thing.pipelines) == 1
assert thing.pipelines[-1]['$group']['price'] == {f'${accumulator}': {'$multiply': ['$price', 10]}}

@pytest.mark.parametrize(
'accumulator',
(
'sum',
"avg",
"first",
"last",
"max",
"min",
"push",
"addToSet",
"stdDevPop",
"stdDevSamp",
)
)
def test_annotate_raw_value(self, accumulator):
aggify = Aggify(BaseModel)
thing = aggify.group().annotate('some_name', accumulator, 'some_value')
assert len(thing.pipelines) == 1
assert thing.pipelines[-1]['$group']['some_name'] == {f'${accumulator}': '$some_value'}

0 comments on commit 44a357b

Please sign in to comment.