Skip to content

Commit

Permalink
Update accumulator parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
mohamadkhalaj committed Oct 30, 2023
1 parent b9fe23e commit efe7d37
Showing 1 changed file with 141 additions and 39 deletions.
180 changes: 141 additions & 39 deletions tests/test_aggify.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,10 @@ def test_complex_conditional_expression_in_projection(self):
"$gt": ["$age", 30]
}
assert (
aggify.pipelines[0]["$project"]["custom_field"]["$cond"]["then"] == "Adult"
aggify.pipelines[0]["$project"]["custom_field"]["$cond"]["then"] == "Adult"
)
assert (
aggify.pipelines[0]["$project"]["custom_field"]["$cond"]["else"] == "Child"
aggify.pipelines[0]["$project"]["custom_field"]["$cond"]["else"] == "Child"
)

# Test filtering using not operator
Expand Down Expand Up @@ -170,7 +170,7 @@ def test_annotate_empty_pipeline_value_error(self):
with pytest.raises(AnnotationError) as err:
Aggify(BaseModel).annotate("size", "sum", None)

assert "you're pipeline is empty" in err.__str__().lower()
assert "your pipeline is empty" in err.__str__().lower()

def test_annotate_not_group_value_error(self):
with pytest.raises(AnnotationError) as err:
Expand All @@ -185,16 +185,27 @@ def test_annotate_invalid_accumulator(self):
@pytest.mark.parametrize(
"accumulator",
(
"sum",
"avg",
"first",
"last",
"max",
"min",
"push",
"addToSet",
"stdDevPop",
"stdDevSamp",
"sum",
"avg",
"stdDevPop",
"stdDevSamp",
"push",
"addToSet",
"count",
"first",
"last",
"max",
"accumulator",
"min",
"median",
"mergeObjects",
"top",
"bottom",
"topN",
"bottomN",
"firstN",
"lastN",
"maxN",
),
)
def test_annotate_with_raw_f(self, accumulator):
Expand All @@ -206,16 +217,27 @@ def test_annotate_with_raw_f(self, accumulator):
@pytest.mark.parametrize(
"accumulator",
(
"sum",
"avg",
"first",
"last",
"max",
"min",
"push",
"addToSet",
"stdDevPop",
"stdDevSamp",
"sum",
"avg",
"stdDevPop",
"stdDevSamp",
"push",
"addToSet",
"count",
"first",
"last",
"max",
"accumulator",
"min",
"median",
"mergeObjects",
"top",
"bottom",
"topN",
"bottomN",
"firstN",
"lastN",
"maxN",
),
)
def test_annotate_with_f(self, accumulator):
Expand All @@ -229,25 +251,105 @@ def test_annotate_with_f(self, accumulator):
@pytest.mark.parametrize(
"accumulator",
(
"sum",
"avg",
"first",
"last",
"max",
"min",
"push",
"addToSet",
"stdDevPop",
"stdDevSamp",
"sum",
"avg",
"stdDevPop",
"stdDevSamp",
"push",
"addToSet",
"count",
"first",
"last",
"max",
"accumulator",
"min",
"median",
"mergeObjects",
"top",
"bottom",
"topN",
"bottomN",
"firstN",
"lastN",
"maxN",
),
)
def test_annotate_raw_value(self, accumulator):
aggify = Aggify(BaseModel)
thing = aggify.group().annotate("some_name", accumulator, "name")
assert len(thing.pipelines) == 1
assert thing.pipelines[-1]["$group"]["some_name"] == {
f"${accumulator}": "$name"
}

@pytest.mark.parametrize(
"accumulator",
(
"sum",
"avg",
"stdDevPop",
"stdDevSamp",
"push",
"addToSet",
"count",
"first",
"last",
"max",
"accumulator",
"min",
"median",
"mergeObjects",
"top",
"bottom",
"topN",
"bottomN",
"firstN",
"lastN",
"maxN",
),
)
def test_annotate_raw_value_not_model_field(self, accumulator):
aggify = Aggify(BaseModel)
thing = aggify.group().annotate("some_name", accumulator, "some_value")
assert len(thing.pipelines) == 1
assert thing.pipelines[-1]["$group"]["some_name"] == {
f"${accumulator}": "some_value"
}

@pytest.mark.parametrize(
"accumulator",
(
"sum",
"avg",
"stdDevPop",
"stdDevSamp",
"push",
"addToSet",
"count",
"first",
"last",
"max",
"accumulator",
"min",
"median",
"mergeObjects",
"top",
"bottom",
"topN",
"bottomN",
"firstN",
"lastN",
"maxN",
),
)
def test_annotate_add_annotated_field_to_base_model(self, accumulator):
aggify = Aggify(BaseModel)
thing = aggify.group().annotate("some_name", accumulator, "some_value")
assert len(thing.pipelines) == 1
assert thing.pipelines[-1]["$group"]["some_name"] == {
f"${accumulator}": "$some_value"
f"${accumulator}": "some_value"
}
assert aggify.filter(some_name=123).pipelines[-1] == {"$match": {"some_name": 123}}

def test_out_with_project_stage_error(self):
with pytest.raises(OutStageError):
Expand All @@ -256,11 +358,11 @@ def test_out_with_project_stage_error(self):
@pytest.mark.parametrize(
("method", "args"),
(
("group", ("_id",)),
("order_by", ("field",)),
("raw", ({"$query": "query"},)),
("add_fields", ({"$field": "value"},)),
("filter", (Q(age=20),)),
("group", ("_id",)),
("order_by", ("field",)),
("raw", ({"$query": "query"},)),
("add_fields", ({"$field": "value"},)),
("filter", (Q(age=20),)),
),
)
def test_out_stage_error(self, method, args):
Expand Down

0 comments on commit efe7d37

Please sign in to comment.