diff --git a/aggify/aggify.py b/aggify/aggify.py index 051df8e..338d2c6 100644 --- a/aggify/aggify.py +++ b/aggify/aggify.py @@ -4,7 +4,7 @@ from mongoengine import Document, EmbeddedDocument, fields from mongoengine.base import TopLevelDocumentMetaclass -from aggify.compiler import F, Match, Q, Operators # noqa keep +from aggify.compiler import F, Match, Q, Operators, Cond # noqa keep from aggify.exceptions import ( AggifyValueError, AnnotationError, @@ -70,10 +70,12 @@ def group(self, expression: str | None = "_id") -> "Aggify": return self @last_out_stage_check - def order_by(self, field: str) -> "Aggify": - self.pipelines.append( - {"$sort": {f'{field.replace("-", "")}': -1 if field.startswith("-") else 1}} - ) + def order_by(self, *fields: str | list[str]) -> "Aggify": + sort_dict = { + field.replace("-", ""): -1 if field.startswith("-") else 1 + for field in fields + } + self.pipelines.append({"$sort": sort_dict}) return self @last_out_stage_check @@ -82,7 +84,7 @@ def raw(self, raw_query: dict) -> "Aggify": return self @last_out_stage_check - def add_fields(self, fields: dict) -> "Aggify": # noqa + def add_fields(self, **fields) -> "Aggify": # noqa """ Generates a MongoDB addFields pipeline stage. @@ -99,6 +101,8 @@ def add_fields(self, fields: dict) -> "Aggify": # noqa add_fields_stage["$addFields"][field] = {"$literal": expression} elif isinstance(expression, F): add_fields_stage["$addFields"][field] = expression.to_dict() + elif isinstance(expression, Cond): + add_fields_stage["$addFields"][field] = dict(expression) else: raise AggifyValueError([str, F], type(expression)) diff --git a/aggify/compiler.py b/aggify/compiler.py index 6b2a97b..ba3e73b 100644 --- a/aggify/compiler.py +++ b/aggify/compiler.py @@ -247,7 +247,10 @@ def is_base_model_field(self, field) -> bool: and the base_model is not None, otherwise False. """ return self.base_model is not None and ( - isinstance(self.base_model._fields.get(field), (EmbeddedDocumentField, TopLevelDocumentMetaclass)) # noqa + isinstance( + self.base_model._fields.get(field), + (EmbeddedDocumentField, TopLevelDocumentMetaclass), + ) # noqa ) def compile(self, pipelines: list) -> dict[str, dict[str, list]]: @@ -263,7 +266,10 @@ def compile(self, pipelines: list) -> dict[str, dict[str, list]]: raise InvalidOperator(key) field, operator, *_ = key.split("__") - if self.is_base_model_field(field) and operator not in Operators.ALL_OPERATORS: + if ( + self.is_base_model_field(field) + and operator not in Operators.ALL_OPERATORS + ): pipelines.append( Match({key.replace("__", ".", 1): value}, self.base_model).compile( [] @@ -274,6 +280,8 @@ def compile(self, pipelines: list) -> dict[str, dict[str, list]]: if operator not in Operators.ALL_OPERATORS: raise InvalidOperator(operator) db_field = get_db_field(self.base_model, field) - match_query = Operators(match_query).compile_match(operator, value, db_field) + match_query = Operators(match_query).compile_match( + operator, value, db_field + ) return {"$match": match_query} diff --git a/aggify/utilty.py b/aggify/utilty.py index bb6e7a3..9b6c4e7 100644 --- a/aggify/utilty.py +++ b/aggify/utilty.py @@ -130,11 +130,12 @@ def check_field_exists(model: Type[Document], field: str) -> None: raise AlreadyExistsField(field=field) -def get_db_field(model: Type[Document], field: str) -> str: +def get_db_field(model: Type[Document], field: str, add_dollar_sign=False) -> str: """ Get the database field name for a given field in the model. Args: + add_dollar_sign: Add a "$" at the start of the field or not model (Document): The model containing the field. field (str): The name of the field. @@ -143,6 +144,7 @@ def get_db_field(model: Type[Document], field: str) -> str: """ try: db_field = model._fields.get(field).db_field # noqa - return field if db_field is None else db_field + db_field = field if db_field is None else db_field + return f"${db_field}" if add_dollar_sign else db_field except AttributeError: return field diff --git a/tests/test_aggify.py b/tests/test_aggify.py index 532ecc1..0438b4c 100644 --- a/tests/test_aggify.py +++ b/tests/test_aggify.py @@ -124,12 +124,12 @@ def test_add_field_value_error(self): fields = { "new_field_1": True, } - aggify.add_fields(fields) + aggify.add_fields(**fields) def test_add_fields_string_literal(self): aggify = Aggify(BaseModel) fields = {"new_field_1": "some_string", "new_field_2": "another_string"} - add_fields_stage = aggify.add_fields(fields) + add_fields_stage = aggify.add_fields(**fields) expected_stage = { "$addFields": { @@ -146,7 +146,7 @@ def test_add_fields_with_f_expression(self): "new_field_1": F("existing_field") + 10, "new_field_2": F("field_a") * F("field_b"), } - add_fields_stage = aggify.add_fields(fields) + add_fields_stage = aggify.add_fields(**fields) expected_stage = { "$addFields": { diff --git a/tests/test_query.py b/tests/test_query.py index 0dd131f..3034669 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -60,15 +60,25 @@ class ParameterTestCase: compiled_query=Aggify(PostDocument).filter( caption__contains="hello", owner__deleted_at=None ), - expected_query=[{'$match': {'caption': {'$options': 'i', '$regex': '.*hello.*'}}}, - {'$lookup': {'as': 'owner', - 'foreignField': '_id', - 'from': 'account', - 'localField': 'owner_id'}}, - {'$unwind': {'includeArrayIndex': None, - 'path': '$owner', - 'preserveNullAndEmptyArrays': True}}, - {'$match': {'owner.deleted_at': None}}], + expected_query=[ + {"$match": {"caption": {"$options": "i", "$regex": ".*hello.*"}}}, + { + "$lookup": { + "as": "owner", + "foreignField": "_id", + "from": "account", + "localField": "owner_id", + } + }, + { + "$unwind": { + "includeArrayIndex": None, + "path": "$owner", + "preserveNullAndEmptyArrays": True, + } + }, + {"$match": {"owner.deleted_at": None}}, + ], ), ParameterTestCase( compiled_query=Aggify(PostDocument) @@ -116,7 +126,7 @@ class ParameterTestCase: ParameterTestCase( compiled_query=( Aggify(PostDocument).add_fields( - { + **{ "new_field_1": "some_string", "new_field_2": F("existing_field") + 10, "new_field_3": F("field_a") * F("field_b"), @@ -179,14 +189,28 @@ class ParameterTestCase: ) .filter(_posts1__ne=[]) ), - expected_query=[{'$lookup': {'as': '_posts1', - 'from': 'account', - 'let': {'owner': '$owner_id'}, - 'pipeline': [{'$match': {'$expr': {'$and': [{'$ne': ['$_id', - '$$owner']}, - {'$ne': ['$username', - 'seyed']}]}}}]}}, - {'$match': {'_posts1': {'$ne': []}}}], + expected_query=[ + { + "$lookup": { + "as": "_posts1", + "from": "account", + "let": {"owner": "$owner_id"}, + "pipeline": [ + { + "$match": { + "$expr": { + "$and": [ + {"$ne": ["$_id", "$$owner"]}, + {"$ne": ["$username", "seyed"]}, + ] + } + } + } + ], + } + }, + {"$match": {"_posts1": {"$ne": []}}}, + ], ), ParameterTestCase( compiled_query=( @@ -199,13 +223,20 @@ class ParameterTestCase: ) .filter(_posts2__ne=[]) ), - expected_query=[{'$lookup': {'as': '_posts2', - 'from': 'account', - 'let': {'caption': '$caption', 'owner': '$owner_id'}, - 'pipeline': [{'$match': {'$expr': {'$eq': ['$_id', '$$owner']}}}, - {'$match': {'$expr': {'$eq': ['$username', - '$$caption']}}}]}}, - {'$match': {'_posts2': {'$ne': []}}}], + expected_query=[ + { + "$lookup": { + "as": "_posts2", + "from": "account", + "let": {"caption": "$caption", "owner": "$owner_id"}, + "pipeline": [ + {"$match": {"$expr": {"$eq": ["$_id", "$$owner"]}}}, + {"$match": {"$expr": {"$eq": ["$username", "$$caption"]}}}, + ], + } + }, + {"$match": {"_posts2": {"$ne": []}}}, + ], ), ParameterTestCase( compiled_query=(Aggify(PostDocument).replace_root(embedded_field="stat")), @@ -255,32 +286,46 @@ class ParameterTestCase: ), ParameterTestCase( compiled_query=( - Aggify(PostDocument) - .lookup( + Aggify(PostDocument).lookup( AccountDocument, - local_field='owner', foreign_field='id', + local_field="owner", + foreign_field="id", as_name="_owner", ) ), - expected_query=[{'$lookup': {'as': '_owner', - 'foreignField': '_id', - 'from': 'account', - 'localField': 'owner_id'}}], + expected_query=[ + { + "$lookup": { + "as": "_owner", + "foreignField": "_id", + "from": "account", + "localField": "owner_id", + } + } + ], ), ParameterTestCase( compiled_query=( Aggify(PostDocument) .lookup( AccountDocument, - local_field='owner', foreign_field='id', + local_field="owner", + foreign_field="id", as_name="_owner1", - ).filter(_owner1__username='Aggify') + ) + .filter(_owner1__username="Aggify") ), - expected_query=[{'$lookup': {'as': '_owner1', - 'foreignField': '_id', - 'from': 'account', - 'localField': 'owner_id'}}, - {'$match': {'_owner1.username': 'Aggify'}}], + expected_query=[ + { + "$lookup": { + "as": "_owner1", + "foreignField": "_id", + "from": "account", + "localField": "owner_id", + } + }, + {"$match": {"_owner1.username": "Aggify"}}, + ], ), ]