
Commit

Merge pull request #14 from mohamadkhalaj/main
Add new accumulators for annotation
mohamadkhalaj authored Oct 30, 2023
2 parents e627c32 + efe7d37 commit 6da8122
Showing 2 changed files with 214 additions and 80 deletions.
114 changes: 73 additions & 41 deletions aggify/aggify.py
@@ -1,5 +1,5 @@
import functools
from typing import Any, Literal, Type
from typing import Any, Literal, Dict, Type

from mongoengine import Document, EmbeddedDocument, fields

@@ -54,8 +54,9 @@ def project(self, **kwargs: QueryParams) -> "Aggify":
return self

@last_out_stage_check
def group(self, key: str = "_id") -> "Aggify":
self.pipelines.append({"$group": {"_id": f"${key}"}})
def group(self, expression: str | None = "_id") -> "Aggify":
expression = f"${expression}" if expression else None
self.pipelines.append({"$group": {"_id": expression}})
return self

@last_out_stage_check
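
A hedged usage sketch of the reworked group() above. Sale is a hypothetical model defined only for these sketches, and Aggify(<model>) is assumed to be the usual constructor, so only the stage that group() appends is asserted:

from mongoengine import Document, fields
from aggify import Aggify   # assumed import path for the package under change

class Sale(Document):        # hypothetical document, not part of this repository
    item = fields.StringField()
    price = fields.FloatField()
    customer = fields.StringField()

Aggify(Sale).group("item")   # appends {"$group": {"_id": "$item"}}
Aggify(Sale).group(None)     # appends {"$group": {"_id": None}}, one bucket for everything
Aggify(Sale).group()         # default "_id" appends {"$group": {"_id": "$_id"}}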
@@ -169,7 +170,7 @@ def __to_aggregate(self, query: dict[str, Any]) -> None:
raise ValueError(f"Invalid field: {split_query[0]}")
# This is a nested query.
if "document_type_obj" not in join_field.__dict__ or issubclass(
join_field.document_type, EmbeddedDocument
join_field.document_type, EmbeddedDocument
):
match = self.__match({key: value})
if (match.get("$match")) != {}:
@@ -213,7 +214,7 @@ def __getitem__(self, index: slice | int) -> "Aggify":

@staticmethod
def unwind(
path: str, preserve: bool = True
path: str, preserve: bool = True
) -> dict[
Literal["$unwind"],
dict[Literal["path", "preserveNullAndEmptyArrays"], str | bool],
@@ -239,49 +240,77 @@ def aggregate(self):
"""
return self.base_model.objects.aggregate(*self.pipelines) # type: ignore
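
aggregate() itself only delegates to mongoengine, handing every stage built so far to the base model's objects.aggregate(). A minimal sketch, reusing the hypothetical Sale model from the group() sketch above:

# Results are whatever mongoengine's aggregate() yields (raw documents from the driver).
for doc in Aggify(Sale).group("item").aggregate():
    print(doc["_id"])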

def annotate(self, annotate_name, accumulator, f):
try:
if (stage := list(self.pipelines[-1].keys())[0]) != "$group":
raise AnnotationError(
f"Annotations apply only to $group, not to {stage}."
)
def annotate(self, annotate_name: str, accumulator: str,
f: str | dict | F | int) -> "Aggify":
"""
Annotate a MongoDB aggregation pipeline with a new field.
Usage: https://www.mongodb.com/docs/manual/reference/operator/aggregation/group/#accumulator-operator
Args:
annotate_name (str): The name of the new annotated field.
accumulator (str): The aggregation accumulator operator (e.g., "$sum", "$avg").
f (str | dict | F | int): The value for the annotated field.
Returns:
self.
Raises:
AnnotationError: If the pipeline is empty or if an invalid accumulator is provided.
Example:
annotate("totalSales", "sum", "sales")
"""

except IndexError as error:
raise AnnotationError(
"Annotations apply only to $group, you're pipeline is empty."
) from error

accumulator_dict = {
"sum": "$sum",
"avg": "$avg",
"first": "$first",
"last": "$last",
"max": "$max",
"min": "$min",
"push": "$push",
"addToSet": "$addToSet",
"stdDevPop": "$stdDevPop",
"stdDevSamp": "$stdDevSamp",
"merge": "$mergeObjects",
# Some of these accumulator field types may be inaccurate and should be double-checked.
aggregation_mapping: Dict[str, Type] = {
"sum": (fields.FloatField(), "$sum"),
"avg": (fields.FloatField(), "$avg"),
"stdDevPop": (fields.FloatField(), "$stdDevPop"),
"stdDevSamp": (fields.FloatField(), "$stdDevSamp"),
"push": (fields.ListField(), "$push"),
"addToSet": (fields.ListField(), "$addToSet"),
"count": (fields.IntField(), "$count"),
"first": (fields.EmbeddedDocumentField(fields.EmbeddedDocument), "$first"),
"last": (fields.EmbeddedDocumentField(fields.EmbeddedDocument), "$last"),
"max": (fields.DynamicField(), "$max"),
"accumulator": (fields.DynamicField(), "$accumulator"),
"min": (fields.DynamicField(), "$min"),
"median": (fields.DynamicField(), "$median"),
"mergeObjects": (fields.DictField(), "$mergeObjects"),
"top": (fields.EmbeddedDocumentField(fields.EmbeddedDocument), "$top"),
"bottom": (fields.EmbeddedDocumentField(fields.EmbeddedDocument), "$bottom"),
"topN": (fields.ListField(), "$topN"),
"bottomN": (fields.ListField(), "$bottomN"),
"firstN": (fields.ListField(), "$firstN"),
"lastN": (fields.ListField(), "$lastN"),
"maxN": (fields.ListField(), "$maxN"),
}

# Determine the data type based on the accumulator
if accumulator in ["sum", "avg", "stdDevPop", "stdDevSamp"]:
field_type = fields.FloatField()
elif accumulator in ["push", "addToSet"]:
field_type = fields.ListField()
else:
field_type = fields.StringField()
try:
stage = list(self.pipelines[-1].keys())[0]
if stage != "$group":
raise AnnotationError(f"Annotations apply only to $group, not to {stage}")
except IndexError:
raise AnnotationError("Annotations apply only to $group, your pipeline is empty")

try:
acc = accumulator_dict[accumulator]
field_type, acc = aggregation_mapping[accumulator]
except KeyError as error:
raise AnnotationError(f"Invalid accumulator: {accumulator}") from error

if isinstance(f, F):
value = f.to_dict()
else:
value = f"${f}"
if isinstance(f, str):
try:
self.get_model_field(self.base_model, f) # noqa
value = f"${f}"
except InvalidField:
value = f
else:
value = f

# Determine the data type based on the aggregation operator
self.pipelines[-1]["$group"] |= {annotate_name: {acc: value}}
self.base_model._fields[annotate_name] = field_type # noqa
return self
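
A hedged end-to-end sketch of the expanded annotate(), again using the hypothetical Sale model from the group() sketch. annotate() must follow a $group stage, resolves a plain string against the model's fields (falling back to a literal value when get_model_field raises InvalidField), passes F expressions through to_dict(), and records the annotated name on the model under the mapped field type:

q = (
    Aggify(Sale)
    .group("item")                               # {"$group": {"_id": "$item"}}
    .annotate("total", "sum", "price")           # FloatField, adds {"total": {"$sum": "$price"}}
    .annotate("buyers", "addToSet", "customer")  # ListField, adds {"buyers": {"$addToSet": "$customer"}}
    .annotate("latest", "last", "item")          # EmbeddedDocumentField placeholder, adds {"latest": {"$last": "$item"}}
)
# q.pipelines[-1]["$group"] now carries all three accumulators and
# Sale._fields gains "total", "buyers" and "latest" entries.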
@@ -300,7 +329,7 @@ def __match(self, matches: dict[str, Any]):

@staticmethod
def __lookup(
from_collection: str, local_field: str, as_name: str, foreign_field: str = "_id"
from_collection: str, local_field: str, as_name: str, foreign_field: str = "_id"
) -> dict[str, dict[str, str]]:
"""
Generates a MongoDB lookup pipeline stage.
@@ -345,8 +374,9 @@ def __combine_sequential_matches(self) -> list[dict[str, dict | Any]]:

return merged_pipeline

@last_out_stage_check
def lookup(
self, from_collection: Document, let: list[str], query: list[Q], as_name: str
self, from_collection: Document, let: list[str], query: list[Q], as_name: str
) -> "Aggify":
"""
Generates a MongoDB lookup pipeline stage.
@@ -363,8 +393,8 @@ def lookup(
check_fields_exist(self.base_model, let) # noqa

let_dict = {
field: f"${self.base_model._fields[field].db_field}" for field in let
} # noqa
field: f"${self.base_model._fields[field].db_field}" for field in let # noqa
}
from_collection = from_collection._meta.get("collection") # noqa

lookup_stages = []
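
A hedged sketch of the public lookup() shown above, now also wrapped in last_out_stage_check. Other, the Q condition and the field names are illustrative assumptions, not calls confirmed by this diff; Q is assumed to be importable from aggify alongside Aggify:

# Illustrative only: let entries are validated against Sale's fields and exposed via
# their db_field names; the exact Q condition syntax accepted here is an assumption.
class Other(Document):                 # hypothetical joined document
    owner = fields.StringField()

Aggify(Sale).lookup(
    from_collection=Other,
    let=["id"],
    query=[Q(owner__exact="id")],
    as_name="owned",
)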
@@ -443,6 +473,7 @@ def _replace_base(self, embedded_field) -> str:

return f"${model_field.db_field}"

@last_out_stage_check
def replace_root(self, *, embedded_field: str, merge: dict | None = None) -> "Aggify":
"""
Replace the root document in the aggregation pipeline with a specified embedded field or a merged result.
@@ -480,6 +511,7 @@ def replace_root(self, *, embedded_field: str, merge: dict | None = None) -> "Aggify":

return self

@last_out_stage_check
def replace_with(self, *, embedded_field: str, merge: dict | None = None) -> "Aggify":
"""
Replace the root document in the aggregation pipeline with a specified embedded field or a merged result.
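
replace_root() and replace_with() above now carry the last_out_stage_check guard as well. A hedged sketch of their shared keyword-only signature, assuming the hypothetical Sale model also defines an embedded address field; the precise merge semantics are not shown in this hunk:

# Illustrative only: promote an embedded field to the document root,
# optionally combined with a merge dict as the docstrings above describe.
Aggify(Sale).replace_root(embedded_field="address")
Aggify(Sale).replace_with(embedded_field="address", merge={"country": "unknown"})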
