Skip to content

Commit

Permalink
Support python 3.8+
Browse files Browse the repository at this point in the history
  • Loading branch information
mohamadkhalaj committed Nov 1, 2023
1 parent fab2ea3 commit e43c0e4
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 53 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ jobs:

steps:
- uses: actions/checkout@v3
- name: Set up Python 3.10
- name: Set up Python 3.8
uses: actions/setup-python@v3
with:
python-version: "3.10"
python-version: "3.8"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand Down
75 changes: 39 additions & 36 deletions aggify/aggify.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import functools
from typing import Any, Dict, Type
from typing import Any, Dict, Type, Union, List

from mongoengine import Document, EmbeddedDocument, fields
from mongoengine.base import TopLevelDocumentMetaclass
Expand Down Expand Up @@ -56,7 +56,7 @@ def __init__(self, base_model: Type[Document]):
self.base_model = type(
"Aggify_base_model", base_model.__bases__, dict(base_model.__dict__)
)
self.pipelines: list[dict[str, dict | Any]] = []
self.pipelines: List[Dict[str, Union[dict, Any]]] = []
self.start = None
self.stop = None
self.q = None
Expand Down Expand Up @@ -112,13 +112,13 @@ def project(self, **kwargs: QueryParams) -> "Aggify":
return self

@last_out_stage_check
def group(self, expression: str | None = "_id") -> "Aggify":
def group(self, expression: Union[str, None] = "_id") -> "Aggify":
expression = f"${expression}" if expression else None
self.pipelines.append({"$group": {"_id": expression}})
return self

@last_out_stage_check
def order_by(self, *fields: str | list[str]) -> "Aggify":
def order_by(self, *fields: Union[str, List[str]]) -> "Aggify":
sort_dict = {
field.replace("-", ""): -1 if field.startswith("-") else 1
for field in fields
Expand All @@ -137,7 +137,7 @@ def add_fields(self, **_fields) -> "Aggify": # noqa
Generates a MongoDB addFields pipeline stage.
Args:
fields: A dictionary of field expressions and values.
_fields: A dictionary of field expressions and values.
Returns:
A MongoDB add_fields pipeline stage.
Expand All @@ -161,7 +161,7 @@ def add_fields(self, **_fields) -> "Aggify": # noqa
return self

@last_out_stage_check
def filter(self, arg: Q | None = None, **kwargs: QueryParams) -> "Aggify":
def filter(self, arg: Union[Q, None] = None, **kwargs: QueryParams) -> "Aggify":
"""
# TODO: missing docs
"""
Expand All @@ -178,7 +178,7 @@ def filter(self, arg: Q | None = None, **kwargs: QueryParams) -> "Aggify":

return self

def out(self, coll: str, db: str | None = None) -> "Aggify":
def out(self, coll: str, db: Union[str, None] = None) -> "Aggify":
"""Write the documents returned by the aggregation pipeline into specified collection.
Starting in MongoDB 4.4, you can specify the output database.
Expand Down Expand Up @@ -219,7 +219,7 @@ def out(self, coll: str, db: str | None = None) -> "Aggify":
self.pipelines.append(stage)
return self

def __to_aggregate(self, query: dict[str, Any]) -> None:
def __to_aggregate(self, query: Dict[str, Any]) -> None:
"""
Builds the pipelines list based on the query parameters.
"""
Expand All @@ -242,7 +242,7 @@ def __to_aggregate(self, query: dict[str, Any]) -> None:
or "document_type_obj"
not in join_field.__dict__ # Check whether this field is a join field or not.
or issubclass(
join_field.document_type, EmbeddedDocument
join_field.document_type, EmbeddedDocument # noqa
) # Check whether this field is embedded field or not
or len(split_query) == 1
or (len(split_query) == 2 and split_query[1] in Operators.ALL_OPERATORS)
Expand Down Expand Up @@ -278,7 +278,7 @@ def __to_aggregate(self, query: dict[str, Any]) -> None:
self.pipelines.extend([{"$match": match} for match in matches])

@last_out_stage_check
def __getitem__(self, index: slice | int) -> "Aggify":
def __getitem__(self, index: Union[slice, int]) -> "Aggify":
"""
# TODO: missing docs
"""
Expand All @@ -292,7 +292,10 @@ def __getitem__(self, index: slice | int) -> "Aggify":

@last_out_stage_check
def unwind(
self, path: str, include_index_array: str | None = None, preserve: bool = False
self,
path: str,
include_index_array: Union[str, None] = None,
preserve: bool = False,
) -> "Aggify":
"""Generates a MongoDB unwind pipeline stage.
Expand Down Expand Up @@ -350,7 +353,7 @@ def aggregate(self):
return self.base_model.objects.aggregate(*self.pipelines) # type: ignore

def annotate(
self, annotate_name: str, accumulator: str, f: str | dict | F | int
self, annotate_name: str, accumulator: str, f: Union[Union[str, Dict], F, int]
) -> "Aggify":
"""
Annotate a MongoDB aggregation pipeline with a new field.
Expand All @@ -359,7 +362,7 @@ def annotate(
Args:
annotate_name (str): The name of the new annotated field.
accumulator (str): The aggregation accumulator operator (e.g., "$sum", "$avg").
f (str | dict | F | int): The value for the annotated field.
f (Union[str, Dict] | F | int): The value for the annotated field.
Returns:
self.
Expand Down Expand Up @@ -420,19 +423,19 @@ def annotate(
else:
if isinstance(f, str):
try:
self.get_model_field(self.base_model, f)
self.get_model_field(self.base_model, f) # noqa
value = f"${f}"
except InvalidField:
value = f
else:
value = f

# Determine the data type based on the aggregation operator
self.pipelines[-1]["$group"] |= {annotate_name: {acc: value}}
self.pipelines[-1]["$group"].update({annotate_name: {acc: value}})
self.base_model._fields[annotate_name] = field_type # noqa
return self

def __match(self, matches: dict[str, Any]):
def __match(self, matches: Dict[str, Any]):
"""
Generates a MongoDB match pipeline stage.
Expand All @@ -442,12 +445,12 @@ def __match(self, matches: dict[str, Any]):
Returns:
A MongoDB match pipeline stage.
"""
return Match(matches, self.base_model).compile(self.pipelines)
return Match(matches, self.base_model).compile(self.pipelines) # noqa

@staticmethod
def __lookup(
from_collection: str, local_field: str, as_name: str, foreign_field: str = "_id"
) -> dict[str, dict[str, str]]:
) -> Dict[str, Dict[str, str]]:
"""
Generates a MongoDB lookup pipeline stage.
Expand All @@ -469,7 +472,7 @@ def __lookup(
}
}

def __combine_sequential_matches(self) -> list[dict[str, dict | Any]]:
def __combine_sequential_matches(self) -> List[Dict[str, Union[dict, Any]]]:
merged_pipeline = []
match_stage = None

Expand All @@ -496,28 +499,28 @@ def lookup(
self,
from_collection: Document,
as_name: str,
query: list[Q] | Q | None = None,
let: list[str] | None = None,
local_field: str | None = None,
foreign_field: str | None = None,
query: Union[List[Q], Union[Q, None]] = None,
let: Union[List[str], None] = None,
local_field: Union[str, None] = None,
foreign_field: Union[str, None] = None,
) -> "Aggify":
"""
Generates a MongoDB lookup pipeline stage.
Args:
from_collection (Document): The document representing the collection to perform the lookup on.
as_name (str): The name of the new field to create.
query (list[Q] | Q | None, optional): List of desired queries with Q function or a single query.
let (list[str] | None, optional): The local field(s) to join on. If provided, localField and foreignField are not used.
local_field (str | None, optional): The local field to join on when let is not provided.
foreign_field (str | None, optional): The foreign field to join on when let is not provided.
query (list[Q] | Union[Q, None], optional): List of desired queries with Q function or a single query.
let (Union[List[str], None], optional): The local field(s) to join on. If provided, localField and foreignField are not used.
local_field (Union[str, None], optional): The local field to join on when let is not provided.
foreign_field (Union[str, None], optional): The foreign field to join on when let is not provided.
Returns:
Aggify: An instance of the Aggify class representing a MongoDB lookup pipeline stage.
"""

lookup_stages = []
check_field_exists(self.base_model, as_name)
check_field_exists(self.base_model, as_name) # noqa
from_collection_name = from_collection._meta.get("collection") # noqa

if not let and not (local_field and foreign_field):
Expand All @@ -532,8 +535,8 @@ def lookup(
"from": from_collection_name,
"localField": get_db_field(self.base_model, local_field), # noqa
"foreignField": get_db_field(
from_collection, foreign_field
), # noqa
from_collection, foreign_field # noqa
),
"as": as_name,
}
}
Expand All @@ -542,7 +545,7 @@ def lookup(
raise InvalidArgument(expected_list=["query"])
check_fields_exist(self.base_model, let) # noqa
let_dict = {
field: f"${get_db_field(self.base_model, field)}"
field: f"${get_db_field(self.base_model, field)}" # noqa
for field in let # noqa
}
for q in query:
Expand Down Expand Up @@ -612,7 +615,7 @@ def _replace_base(self, embedded_field) -> str:
Raises:
InvalidEmbeddedField: If the specified embedded field is not found or is not of the correct type.
"""
model_field = self.get_model_field(self.base_model, embedded_field)
model_field = self.get_model_field(self.base_model, embedded_field) # noqa

if not hasattr(model_field, "document_type") or not issubclass(
model_field.document_type, EmbeddedDocument
Expand All @@ -623,14 +626,14 @@ def _replace_base(self, embedded_field) -> str:

@last_out_stage_check
def replace_root(
self, *, embedded_field: str, merge: dict | None = None
self, *, embedded_field: str, merge: Union[Dict, None] = None
) -> "Aggify":
"""
Replace the root document in the aggregation pipeline with a specified embedded field or a merged result.
Args:
embedded_field (str): The name of the embedded field to use as the new root.
merge (dict | None, optional): A dictionary for merging with the new root. Default is None.
merge (Union[Dict, None], optional): A dictionary for merging with the new root. Default is None.
Returns:
Aggify: The modified Aggify instance.
Expand All @@ -651,14 +654,14 @@ def replace_root(

@last_out_stage_check
def replace_with(
self, *, embedded_field: str, merge: dict | None = None
self, *, embedded_field: str, merge: Union[Dict, None] = None
) -> "Aggify":
"""
Replace the root document in the aggregation pipeline with a specified embedded field or a merged result.
Args:
embedded_field (str): The name of the embedded field to use as the new root.
merge (dict | None, optional): A dictionary for merging with the new root. Default is None.
merge (Union[Dict, None], optional): A dictionary for merging with the new root. Default is None.
Returns:
Aggify: The modified Aggify instance.
Expand Down
14 changes: 8 additions & 6 deletions aggify/compiler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Type
from typing import Any, Type, Union, Dict

from mongoengine import Document, EmbeddedDocumentField
from mongoengine.base import TopLevelDocumentMetaclass
Expand Down Expand Up @@ -34,7 +34,7 @@ class Operators:
**COMPARISON_OPERATORS,
}

def __init__(self, match_query: dict[str, Any]):
def __init__(self, match_query: Dict[str, Any]):
self.match_query = match_query

def compile_match(self, operator: str, value, field: str):
Expand Down Expand Up @@ -76,7 +76,7 @@ def compile_match(self, operator: str, value, field: str):


class Q:
def __init__(self, pipeline: list | None = None, **conditions):
def __init__(self, pipeline: Union[list, None] = None, **conditions):
pipeline = pipeline or []
self.conditions: dict[str, list] = (
Match(
Expand Down Expand Up @@ -113,7 +113,7 @@ def __invert__(self):


class F:
def __init__(self, field: str | dict[str, list]):
def __init__(self, field: Union[str, Dict[str, list]]):
if isinstance(field, str):
self.field = f"${field.replace('__', '.')}"
else:
Expand Down Expand Up @@ -224,7 +224,9 @@ def __iter__(self):


class Match:
def __init__(self, matches: dict[str, Any], base_model: Type[Document] | None):
def __init__(
self, matches: Dict[str, Any], base_model: Union[Type[Document], None]
):
self.matches = matches
self.base_model = base_model

Expand Down Expand Up @@ -254,7 +256,7 @@ def is_base_model_field(self, field) -> bool:
)
)

def compile(self, pipelines: list) -> dict[str, dict[str, list]]:
def compile(self, pipelines: list) -> Dict[str, Dict[str, list]]:
match_query = {}
for key, value in self.matches.items():
if "__" not in key:
Expand Down
4 changes: 2 additions & 2 deletions aggify/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Type
from typing import Type, List


class AggifyBaseException(Exception):
Expand Down Expand Up @@ -33,7 +33,7 @@ def __init__(self, stage):


class AggifyValueError(AggifyBaseException):
def __init__(self, expected_list: list[Type], result: Type):
def __init__(self, expected_list: List[Type], result: Type):
self.message = (
f"Input is not correctly passed, expected either of {[expected for expected in expected_list]}"
f"but got {result}"
Expand Down
4 changes: 3 additions & 1 deletion aggify/types.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
QueryParams = int | None | str | bool | float | dict
from typing import Union, Dict

QueryParams = Union[int, None, str, bool, float, Dict]
12 changes: 7 additions & 5 deletions aggify/utilty.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Type
from typing import Any, Type, Union, List, Dict

from mongoengine import Document

Expand All @@ -16,7 +16,7 @@ def int_to_slice(final_index: int) -> slice:
return slice(0, final_index)


def to_mongo_positive_index(index: int | slice) -> slice:
def to_mongo_positive_index(index: Union[int, slice]) -> slice:
if isinstance(index, int):
if index < 0:
raise MongoIndexError
Expand All @@ -33,7 +33,7 @@ def to_mongo_positive_index(index: int | slice) -> slice:
return index


def check_fields_exist(model: Document, fields_to_check: list[str]) -> None:
def check_fields_exist(model: Document, fields_to_check: List[str]) -> None:
"""
Check if the specified fields exist in a model's fields.
Expand Down Expand Up @@ -79,8 +79,10 @@ def replace_values_recursive(obj, replacements):


def convert_match_query(
d: dict,
) -> dict[Any, list[str | Any] | dict] | list[dict] | dict:
d: Dict,
) -> Union[Dict[Any, Union[List[Union[str, Any]], Dict]], List[Dict], Dict]:
pass

"""
Recursively transform a dictionary to modify the structure of '$eq' and '$ne' operators.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_aggify.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test__getitem__slice(self):

def test__getitem__value_error(self):
with pytest.raises(AggifyValueError) as err:
Aggify(BaseModel)["hello"] # type: ignore
Aggify(BaseModel)["hello"] # type: ignore # noqa

assert "str" in err.__str__(), "wrong type was not detected"

Expand Down

0 comments on commit e43c0e4

Please sign in to comment.