From ad980fba22c6f6310ed85f0737402d6c8cde63e6 Mon Sep 17 00:00:00 2001 From: MohammadMahdi Khalaj Date: Tue, 31 Oct 2023 11:59:36 +0330 Subject: [PATCH 1/6] Improve lookup --- aggify/aggify.py | 132 ++++++++++++++++++++++++++------------------- aggify/compiler.py | 12 +++-- 2 files changed, 83 insertions(+), 61 deletions(-) diff --git a/aggify/aggify.py b/aggify/aggify.py index 35fa41b..64b735c 100644 --- a/aggify/aggify.py +++ b/aggify/aggify.py @@ -1,7 +1,8 @@ import functools -from typing import Any, Literal, Dict, Type +from typing import Any, Dict, Type from mongoengine import Document, EmbeddedDocument, fields +from mongoengine.base import TopLevelDocumentMetaclass from aggify.compiler import F, Match, Q, Operators # noqa keep from aggify.exceptions import ( @@ -9,14 +10,14 @@ AnnotationError, InvalidField, InvalidEmbeddedField, - OutStageError, + OutStageError, InvalidArgument, ) from aggify.types import QueryParams from aggify.utilty import ( to_mongo_positive_index, check_fields_exist, replace_values_recursive, - convert_match_query, + convert_match_query, check_field_exists, get_db_field, ) @@ -174,14 +175,14 @@ def __to_aggregate(self, query: dict[str, Any]) -> None: split_query = key.split("__") # Retrieve the field definition from the model. - join_field = self.get_model_field(self.base_model, split_query[0]) # type: ignore # noqa - + join_field = self.get_model_field(self.base_model, split_query[0]) # type: ignore # Check conditions for creating a 'match' pipeline stage. if ( - "document_type_obj" not in join_field.__dict__ - or issubclass(join_field.document_type, EmbeddedDocument) - or len(split_query) == 1 - or (len(split_query) == 2 and split_query[1] in Operators.ALL_OPERATORS) + isinstance(join_field, TopLevelDocumentMetaclass) + or "document_type_obj" not in join_field.__dict__ + or issubclass(join_field.document_type, EmbeddedDocument) + or len(split_query) == 1 + or (len(split_query) == 2 and split_query[1] in Operators.ALL_OPERATORS) ): # Create a 'match' pipeline stage. match = self.__match({key: value}) @@ -191,7 +192,7 @@ def __to_aggregate(self, query: dict[str, Any]) -> None: self.pipelines.append(match) else: - from_collection = join_field.document_type # noqa + from_collection = join_field.document_type local_field = join_field.db_field as_name = join_field.name matches = [] @@ -210,7 +211,7 @@ def __to_aggregate(self, query: dict[str, Any]) -> None: as_name=as_name, ) ) - self.unwind(as_name) + self.unwind(as_name, preserve=True) self.pipelines.extend([{"$match": match} for match in matches]) @last_out_stage_check @@ -228,7 +229,7 @@ def __getitem__(self, index: slice | int) -> "Aggify": @last_out_stage_check def unwind( - self, path: str, include_index_array: str | None = None, preserve: bool = False + self, path: str, include_index_array: str | None = None, preserve: bool = False ) -> "Aggify": """Generates a MongoDB unwind pipeline stage. @@ -286,7 +287,7 @@ def aggregate(self): return self.base_model.objects.aggregate(*self.pipelines) # type: ignore def annotate( - self, annotate_name: str, accumulator: str, f: str | dict | F | int + self, annotate_name: str, accumulator: str, f: str | dict | F | int ) -> "Aggify": """ Annotate a MongoDB aggregation pipeline with a new field. 
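
For reference, the reworked filter path above now emits the implicit lookup together with an `$unwind` that keeps null/empty arrays (`preserve=True`). A minimal sketch of what it produces for a reference-field filter — the expected pipeline is lifted from the updated test case later in this series, and PostDocument/AccountDocument are the models used by that test suite (`owner` stores to db_field "owner_id"):

```python
from aggify.aggify import Aggify

# PostDocument / AccountDocument are the mongoengine documents defined for
# tests/test_query.py; they are assumed here, not redefined.
pipeline = Aggify(PostDocument).filter(
    caption__contains="hello", owner__deleted_at=None
).pipelines

# pipeline ==
# [
#     {"$match": {"caption": {"$options": "i", "$regex": ".*hello.*"}}},
#     {"$lookup": {"from": "account", "localField": "owner_id",
#                  "foreignField": "_id", "as": "owner"}},
#     {"$unwind": {"path": "$owner", "includeArrayIndex": None,
#                  "preserveNullAndEmptyArrays": True}},
#     {"$match": {"owner.deleted_at": None}},
# ]
```
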
@@ -356,7 +357,7 @@ def annotate( else: if isinstance(f, str): try: - self.get_model_field(self.base_model, f) # noqa + self.get_model_field(self.base_model, f) value = f"${f}" except InvalidField: value = f @@ -382,7 +383,7 @@ def __match(self, matches: dict[str, Any]): @staticmethod def __lookup( - from_collection: str, local_field: str, as_name: str, foreign_field: str = "_id" + from_collection: str, local_field: str, as_name: str, foreign_field: str = "_id" ) -> dict[str, dict[str, str]]: """ Generates a MongoDB lookup pipeline stage. @@ -429,66 +430,85 @@ def __combine_sequential_matches(self) -> list[dict[str, dict | Any]]: @last_out_stage_check def lookup( - self, from_collection: Document, let: list[str], query: list[Q], as_name: str + self, from_collection: Document, as_name: str, query: list[Q] | Q | None = None, + let: list[str] | None = None, local_field: str | None = None, foreign_field: str | None = None ) -> "Aggify": """ Generates a MongoDB lookup pipeline stage. Args: - from_collection (Document): The name of the collection to lookup. - let (list): The local field(s) to join on. - query (list[Q]): List of desired queries with Q function. + from_collection (Document): The document representing the collection to perform the lookup on. as_name (str): The name of the new field to create. + query (list[Q] | Q | None, optional): List of desired queries with Q function or a single query. + let (list[str] | None, optional): The local field(s) to join on. If provided, localField and foreignField are not used. + local_field (str | None, optional): The local field to join on when let is not provided. + foreign_field (str | None, optional): The foreign field to join on when let is not provided. Returns: - Aggify: A MongoDB lookup pipeline stage. + Aggify: An instance of the Aggify class representing a MongoDB lookup pipeline stage. 
""" - check_fields_exist(self.base_model, let) # noqa - - let_dict = { - field: f"${self.base_model._fields[field].db_field}" - for field in let # noqa - } - from_collection = from_collection._meta.get("collection") # noqa lookup_stages = [] - - for q in query: - # Construct the match stage for each query - if isinstance(q, Q): - replaced_values = replace_values_recursive( - convert_match_query(dict(q)), # noqa - {field: f"$${field}" for field in let}, - ) - match_stage = {"$match": {"$expr": replaced_values.get("$match")}} - lookup_stages.append(match_stage) - elif isinstance(q, Aggify): - lookup_stages.extend( - replace_values_recursive( - convert_match_query(q.pipelines), # noqa + check_field_exists(self.base_model, as_name) + from_collection_name = from_collection._meta.get("collection") # noqa + + if not let and not (local_field and foreign_field): + raise InvalidArgument(expected_list=[['local_field', 'foreign_field'], 'let']) + elif not let: + if not (local_field and foreign_field): + raise InvalidArgument(expected_list=['local_field', 'foreign_field']) + lookup_stage = { + "$lookup": { + "from": from_collection_name, + 'localField': get_db_field(self.base_model, local_field), # noqa + 'foreignField': get_db_field(from_collection, foreign_field), # noqa + "as": as_name, + } + } + else: + if not query: + raise InvalidArgument(expected_list=['query']) + check_fields_exist(self.base_model, let) # noqa + let_dict = { + field: f"${get_db_field(self.base_model, field)}" + for field in let # noqa + } + for q in query: + # Construct the match stage for each query + if isinstance(q, Q): + replaced_values = replace_values_recursive( + convert_match_query(dict(q)), {field: f"$${field}" for field in let}, ) - ) + match_stage = {"$match": {"$expr": replaced_values.get("$match")}} + lookup_stages.append(match_stage) + elif isinstance(q, Aggify): + lookup_stages.extend( + replace_values_recursive( + convert_match_query(q.pipelines), # noqa + {field: f"$${field}" for field in let}, + ) + ) - # Append the lookup stage with multiple match stages to the pipeline - lookup_stage = { - "$lookup": { - "from": from_collection, - "let": let_dict, - "pipeline": lookup_stages, # List of match stages - "as": as_name, + # Append the lookup stage with multiple match stages to the pipeline + lookup_stage = { + "$lookup": { + "from": from_collection_name, + "let": let_dict, + "pipeline": lookup_stages, # List of match stages + "as": as_name, + } } - } self.pipelines.append(lookup_stage) # Add this new field to base model fields, which we can use it in the next stages. - self.base_model._fields[as_name] = fields.StringField() # noqa + self.base_model._fields[as_name] = from_collection # noqa return self @staticmethod - def get_model_field(model: Document, field: str) -> fields: + def get_model_field(model: Type[Document], field: str) -> fields: """ Get the field definition of a specified field in a MongoDB model. @@ -520,10 +540,10 @@ def _replace_base(self, embedded_field) -> str: Raises: InvalidEmbeddedField: If the specified embedded field is not found or is not of the correct type. 
""" - model_field = self.get_model_field(self.base_model, embedded_field) # noqa + model_field = self.get_model_field(self.base_model, embedded_field) if not hasattr(model_field, "document_type") or not issubclass( - model_field.document_type, EmbeddedDocument + model_field.document_type, EmbeddedDocument ): raise InvalidEmbeddedField(field=embedded_field) @@ -531,7 +551,7 @@ def _replace_base(self, embedded_field) -> str: @last_out_stage_check def replace_root( - self, *, embedded_field: str, merge: dict | None = None + self, *, embedded_field: str, merge: dict | None = None ) -> "Aggify": """ Replace the root document in the aggregation pipeline with a specified embedded field or a merged result. @@ -559,7 +579,7 @@ def replace_root( @last_out_stage_check def replace_with( - self, *, embedded_field: str, merge: dict | None = None + self, *, embedded_field: str, merge: dict | None = None ) -> "Aggify": """ Replace the root document in the aggregation pipeline with a specified embedded field or a merged result. diff --git a/aggify/compiler.py b/aggify/compiler.py index 6b50342..8529d94 100644 --- a/aggify/compiler.py +++ b/aggify/compiler.py @@ -1,8 +1,10 @@ from typing import Any, Type from mongoengine import Document, EmbeddedDocumentField +from mongoengine.base import TopLevelDocumentMetaclass from aggify.exceptions import InvalidOperator +from aggify.utilty import get_db_field class Operators: @@ -232,15 +234,15 @@ def validate_operator(key: str): raise InvalidOperator(operator) def is_base_model_field(self, field) -> bool: - return self.base_model is not None and isinstance( - self.base_model._fields.get(field), # type: ignore # noqa - EmbeddedDocumentField, + return self.base_model is not None and ( + isinstance(self.base_model._fields.get(field), (EmbeddedDocumentField, TopLevelDocumentMetaclass)) # noqa ) def compile(self, pipelines: list) -> dict[str, dict[str, list]]: match_query = {} for key, value in self.matches.items(): if "__" not in key: + key = get_db_field(self.base_model, key) match_query[key] = value continue @@ -259,7 +261,7 @@ def compile(self, pipelines: list) -> dict[str, dict[str, list]]: if operator not in Operators.ALL_OPERATORS: raise InvalidOperator(operator) - - match_query = Operators(match_query).compile_match(operator, value, field) + db_field = get_db_field(self.base_model, field) + match_query = Operators(match_query).compile_match(operator, value, db_field) return {"$match": match_query} From 4d8c5dc5095994d07ffa5acba0a835e59219f897 Mon Sep 17 00:00:00 2001 From: MohammadMahdi Khalaj Date: Tue, 31 Oct 2023 12:22:39 +0330 Subject: [PATCH 2/6] Add field checker --- aggify/utilty.py | 39 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 36 insertions(+), 3 deletions(-) diff --git a/aggify/utilty.py b/aggify/utilty.py index 079dc46..194a230 100644 --- a/aggify/utilty.py +++ b/aggify/utilty.py @@ -1,8 +1,8 @@ -from typing import Any +from typing import Any, Type from mongoengine import Document -from aggify.exceptions import MongoIndexError, InvalidField +from aggify.exceptions import MongoIndexError, InvalidField, AlreadyExistsField def int_to_slice(final_index: int) -> slice: @@ -79,7 +79,7 @@ def replace_values_recursive(obj, replacements): def convert_match_query( - d: dict, + d: dict, ) -> dict[Any, list[str | Any] | dict] | list[dict] | dict: """ Recursively transform a dictionary to modify the structure of '$eq' and '$ne' operators. 
@@ -113,3 +113,36 @@ def convert_match_query( return [convert_match_query(item) for item in d] else: return d + + +def check_field_exists(model: Type[Document], field: str) -> None: + """ + Check if a field exists in the given model. + + Args: + model (Document): The model to check for the field. + field (str): The name of the field to check. + + Raises: + AlreadyExistsField: If the field already exists in the model. + """ + if model._fields.get(field): # noqa + raise AlreadyExistsField(field=field) + + +def get_db_field(model: Type[Document], field: str) -> str: + """ + Get the database field name for a given field in the model. + + Args: + model (Document): The model containing the field. + field (str): The name of the field. + + Returns: + str: The database field name if available, otherwise the original field name. + """ + try: + db_field = model._fields.get(field).db_field # noqa + return field if db_field is None else db_field + except AttributeError: + return field From c1e116621ed0c98e850e2f7dc8332eea373a463f Mon Sep 17 00:00:00 2001 From: MohammadMahdi Khalaj Date: Tue, 31 Oct 2023 12:23:03 +0330 Subject: [PATCH 3/6] Add lookup new cases --- tests/test_query.py | 115 +++++++++++++++++++++++--------------------- 1 file changed, 59 insertions(+), 56 deletions(-) diff --git a/tests/test_query.py b/tests/test_query.py index c8e21e1..0dd131f 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -60,20 +60,15 @@ class ParameterTestCase: compiled_query=Aggify(PostDocument).filter( caption__contains="hello", owner__deleted_at=None ), - expected_query=[ - {"$match": {"caption": {"$options": "i", "$regex": ".*hello.*"}}}, - { - "$lookup": { - "as": "owner", - "foreignField": "_id", - "from": "account", - "localField": "owner_id", - } - }, - {"$unwind": "$owner"}, - {"$match": {"owner.deleted_at": None}}, - - ], + expected_query=[{'$match': {'caption': {'$options': 'i', '$regex': '.*hello.*'}}}, + {'$lookup': {'as': 'owner', + 'foreignField': '_id', + 'from': 'account', + 'localField': 'owner_id'}}, + {'$unwind': {'includeArrayIndex': None, + 'path': '$owner', + 'preserveNullAndEmptyArrays': True}}, + {'$match': {'owner.deleted_at': None}}], ), ParameterTestCase( compiled_query=Aggify(PostDocument) @@ -146,13 +141,13 @@ class ParameterTestCase: Q(_id__ne="owner") & Q(username__ne="seyed"), ], let=["owner"], - as_name="posts", + as_name="_posts", ) ), expected_query=[ { "$lookup": { - "as": "posts", + "as": "_posts", "from": "account", "let": {"owner": "$owner_id"}, "pipeline": [ @@ -180,32 +175,18 @@ class ParameterTestCase: Q(_id__ne="owner") & Q(username__ne="seyed"), ], let=["owner"], - as_name="posts", + as_name="_posts1", ) - .filter(posts__ne=[]) + .filter(_posts1__ne=[]) ), - expected_query=[ - { - "$lookup": { - "as": "posts", - "from": "account", - "let": {"owner": "$owner_id"}, - "pipeline": [ - { - "$match": { - "$expr": { - "$and": [ - {"$ne": ["$_id", "$$owner"]}, - {"$ne": ["$username", "seyed"]}, - ] - } - } - } - ], - } - }, - {"$match": {"posts": {"$ne": []}}}, - ], + expected_query=[{'$lookup': {'as': '_posts1', + 'from': 'account', + 'let': {'owner': '$owner_id'}, + 'pipeline': [{'$match': {'$expr': {'$and': [{'$ne': ['$_id', + '$$owner']}, + {'$ne': ['$username', + 'seyed']}]}}}]}}, + {'$match': {'_posts1': {'$ne': []}}}], ), ParameterTestCase( compiled_query=( @@ -214,24 +195,17 @@ class ParameterTestCase: AccountDocument, query=[Q(_id__exact="owner"), Q(username__exact="caption")], # noqa let=["owner", "caption"], - as_name="posts", + as_name="_posts2", ) - 
.filter(posts__ne=[]) + .filter(_posts2__ne=[]) ), - expected_query=[ - { - "$lookup": { - "as": "posts", - "from": "account", - "let": {"caption": "$caption", "owner": "$owner_id"}, - "pipeline": [ - {"$match": {"$expr": {"$eq": ["$_id", "$$owner"]}}}, - {"$match": {"$expr": {"$eq": ["$username", "$$caption"]}}}, - ], - } - }, - {"$match": {"posts": {"$ne": []}}}, - ], + expected_query=[{'$lookup': {'as': '_posts2', + 'from': 'account', + 'let': {'caption': '$caption', 'owner': '$owner_id'}, + 'pipeline': [{'$match': {'$expr': {'$eq': ['$_id', '$$owner']}}}, + {'$match': {'$expr': {'$eq': ['$username', + '$$caption']}}}]}}, + {'$match': {'_posts2': {'$ne': []}}}], ), ParameterTestCase( compiled_query=(Aggify(PostDocument).replace_root(embedded_field="stat")), @@ -279,6 +253,35 @@ class ParameterTestCase: } ], ), + ParameterTestCase( + compiled_query=( + Aggify(PostDocument) + .lookup( + AccountDocument, + local_field='owner', foreign_field='id', + as_name="_owner", + ) + ), + expected_query=[{'$lookup': {'as': '_owner', + 'foreignField': '_id', + 'from': 'account', + 'localField': 'owner_id'}}], + ), + ParameterTestCase( + compiled_query=( + Aggify(PostDocument) + .lookup( + AccountDocument, + local_field='owner', foreign_field='id', + as_name="_owner1", + ).filter(_owner1__username='Aggify') + ), + expected_query=[{'$lookup': {'as': '_owner1', + 'foreignField': '_id', + 'from': 'account', + 'localField': 'owner_id'}}, + {'$match': {'_owner1.username': 'Aggify'}}], + ), ] From 1fadfa3305df3c16ed3e4d44572718e2795c822c Mon Sep 17 00:00:00 2001 From: MohammadMahdi Khalaj Date: Tue, 31 Oct 2023 12:23:31 +0330 Subject: [PATCH 4/6] Fix operator problem --- aggify/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aggify/compiler.py b/aggify/compiler.py index 8529d94..e53a3d8 100644 --- a/aggify/compiler.py +++ b/aggify/compiler.py @@ -251,7 +251,7 @@ def compile(self, pipelines: list) -> dict[str, dict[str, list]]: raise InvalidOperator(key) field, operator, *_ = key.split("__") - if self.is_base_model_field(field): + if self.is_base_model_field(field) and operator not in Operators.ALL_OPERATORS: pipelines.append( Match({key.replace("__", ".", 1): value}, self.base_model).compile( [] From 467ac8d57c8bcfba7b094d870ec93869df059e09 Mon Sep 17 00:00:00 2001 From: MohammadMahdi Khalaj Date: Tue, 31 Oct 2023 12:23:49 +0330 Subject: [PATCH 5/6] Add AlreadyExistsField --- aggify/exceptions.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/aggify/exceptions.py b/aggify/exceptions.py index 04a9b89..089e13b 100644 --- a/aggify/exceptions.py +++ b/aggify/exceptions.py @@ -33,6 +33,7 @@ def __init__(self, stage): class AggifyValueError(AggifyBaseException): + def __init__(self, expected_list: list[Type], result: Type): self.message = ( f"Input is not correctly passed, expected either of {[expected for expected in expected_list]}" @@ -59,3 +60,19 @@ class InvalidEmbeddedField(AggifyBaseException): def __init__(self, field: str): self.message = f"Field {field} is not embedded." super().__init__(self.message) + + +class AlreadyExistsField(AggifyBaseException): + def __init__(self, field: str): + self.message = f"Field {field} already exists." 
+ super().__init__(self.message) + + +class InvalidArgument(AggifyBaseException): + + def __init__(self, expected_list: list): + self.message = ( + f"Input is not correctly passed, expected {[expected for expected in expected_list]}" + ) + self.expecteds = expected_list + super().__init__(self.message) \ No newline at end of file From 39bdb406c109483da9e5ce2d07e909b401ef7020 Mon Sep 17 00:00:00 2001 From: MohammadMahdi Khalaj Date: Tue, 31 Oct 2023 12:58:12 +0330 Subject: [PATCH 6/6] Reformat code with Black --- aggify/aggify.py | 57 ++++++++++++++++++++++++++++---------------- aggify/compiler.py | 12 ++++++++++ aggify/exceptions.py | 8 ++----- aggify/utilty.py | 16 ++++++------- 4 files changed, 59 insertions(+), 34 deletions(-) diff --git a/aggify/aggify.py b/aggify/aggify.py index 64b735c..051df8e 100644 --- a/aggify/aggify.py +++ b/aggify/aggify.py @@ -10,14 +10,17 @@ AnnotationError, InvalidField, InvalidEmbeddedField, - OutStageError, InvalidArgument, + OutStageError, + InvalidArgument, ) from aggify.types import QueryParams from aggify.utilty import ( to_mongo_positive_index, check_fields_exist, replace_values_recursive, - convert_match_query, check_field_exists, get_db_field, + convert_match_query, + check_field_exists, + get_db_field, ) @@ -178,11 +181,16 @@ def __to_aggregate(self, query: dict[str, Any]) -> None: join_field = self.get_model_field(self.base_model, split_query[0]) # type: ignore # Check conditions for creating a 'match' pipeline stage. if ( - isinstance(join_field, TopLevelDocumentMetaclass) - or "document_type_obj" not in join_field.__dict__ - or issubclass(join_field.document_type, EmbeddedDocument) - or len(split_query) == 1 - or (len(split_query) == 2 and split_query[1] in Operators.ALL_OPERATORS) + isinstance( + join_field, TopLevelDocumentMetaclass + ) # check whether field is added by lookup stage or not + or "document_type_obj" + not in join_field.__dict__ # Check whether this field is a join field or not. + or issubclass( + join_field.document_type, EmbeddedDocument + ) # Check whether this field is embedded field or not + or len(split_query) == 1 + or (len(split_query) == 2 and split_query[1] in Operators.ALL_OPERATORS) ): # Create a 'match' pipeline stage. match = self.__match({key: value}) @@ -229,7 +237,7 @@ def __getitem__(self, index: slice | int) -> "Aggify": @last_out_stage_check def unwind( - self, path: str, include_index_array: str | None = None, preserve: bool = False + self, path: str, include_index_array: str | None = None, preserve: bool = False ) -> "Aggify": """Generates a MongoDB unwind pipeline stage. @@ -287,7 +295,7 @@ def aggregate(self): return self.base_model.objects.aggregate(*self.pipelines) # type: ignore def annotate( - self, annotate_name: str, accumulator: str, f: str | dict | F | int + self, annotate_name: str, accumulator: str, f: str | dict | F | int ) -> "Aggify": """ Annotate a MongoDB aggregation pipeline with a new field. @@ -383,7 +391,7 @@ def __match(self, matches: dict[str, Any]): @staticmethod def __lookup( - from_collection: str, local_field: str, as_name: str, foreign_field: str = "_id" + from_collection: str, local_field: str, as_name: str, foreign_field: str = "_id" ) -> dict[str, dict[str, str]]: """ Generates a MongoDB lookup pipeline stage. 
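
The Black reformat below also lays out the reworked public `lookup()` signature in full; the two call shapes it supports, taken from the new test cases, look like this (model names come from the test suite, and `owner` is assumed to store to db_field "owner_id"):

```python
from aggify.aggify import Aggify
from aggify.compiler import Q

# 1) Plain localField/foreignField join — no `let`, no sub-pipeline.
Aggify(PostDocument).lookup(
    AccountDocument,
    as_name="_owner",
    local_field="owner",
    foreign_field="id",
)
# -> {"$lookup": {"from": "account", "localField": "owner_id",
#                 "foreignField": "_id", "as": "_owner"}}

# 2) Correlated sub-pipeline join — `let` plus `query`; supplying neither
#    shape raises the new InvalidArgument exception.
Aggify(PostDocument).lookup(
    AccountDocument,
    as_name="_posts",
    let=["owner"],
    query=[Q(_id__ne="owner") & Q(username__ne="seyed")],
)
```
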
@@ -430,8 +438,13 @@ def __combine_sequential_matches(self) -> list[dict[str, dict | Any]]: @last_out_stage_check def lookup( - self, from_collection: Document, as_name: str, query: list[Q] | Q | None = None, - let: list[str] | None = None, local_field: str | None = None, foreign_field: str | None = None + self, + from_collection: Document, + as_name: str, + query: list[Q] | Q | None = None, + let: list[str] | None = None, + local_field: str | None = None, + foreign_field: str | None = None, ) -> "Aggify": """ Generates a MongoDB lookup pipeline stage. @@ -453,21 +466,25 @@ def lookup( from_collection_name = from_collection._meta.get("collection") # noqa if not let and not (local_field and foreign_field): - raise InvalidArgument(expected_list=[['local_field', 'foreign_field'], 'let']) + raise InvalidArgument( + expected_list=[["local_field", "foreign_field"], "let"] + ) elif not let: if not (local_field and foreign_field): - raise InvalidArgument(expected_list=['local_field', 'foreign_field']) + raise InvalidArgument(expected_list=["local_field", "foreign_field"]) lookup_stage = { "$lookup": { "from": from_collection_name, - 'localField': get_db_field(self.base_model, local_field), # noqa - 'foreignField': get_db_field(from_collection, foreign_field), # noqa + "localField": get_db_field(self.base_model, local_field), # noqa + "foreignField": get_db_field( + from_collection, foreign_field + ), # noqa "as": as_name, } } else: if not query: - raise InvalidArgument(expected_list=['query']) + raise InvalidArgument(expected_list=["query"]) check_fields_exist(self.base_model, let) # noqa let_dict = { field: f"${get_db_field(self.base_model, field)}" @@ -543,7 +560,7 @@ def _replace_base(self, embedded_field) -> str: model_field = self.get_model_field(self.base_model, embedded_field) if not hasattr(model_field, "document_type") or not issubclass( - model_field.document_type, EmbeddedDocument + model_field.document_type, EmbeddedDocument ): raise InvalidEmbeddedField(field=embedded_field) @@ -551,7 +568,7 @@ def _replace_base(self, embedded_field) -> str: @last_out_stage_check def replace_root( - self, *, embedded_field: str, merge: dict | None = None + self, *, embedded_field: str, merge: dict | None = None ) -> "Aggify": """ Replace the root document in the aggregation pipeline with a specified embedded field or a merged result. @@ -579,7 +596,7 @@ def replace_root( @last_out_stage_check def replace_with( - self, *, embedded_field: str, merge: dict | None = None + self, *, embedded_field: str, merge: dict | None = None ) -> "Aggify": """ Replace the root document in the aggregation pipeline with a specified embedded field or a merged result. diff --git a/aggify/compiler.py b/aggify/compiler.py index e53a3d8..6b2a97b 100644 --- a/aggify/compiler.py +++ b/aggify/compiler.py @@ -234,6 +234,18 @@ def validate_operator(key: str): raise InvalidOperator(operator) def is_base_model_field(self, field) -> bool: + """ + Check if a field in the base model class is of a specific type. + EmbeddedDocumentField: Field which is embedded. + TopLevelDocumentMetaclass: Field which is added by lookup stage. + + Args: + field (str): The name of the field to check. + + Returns: + bool: True if the field is of type EmbeddedDocumentField or TopLevelDocumentMetaclass + and the base_model is not None, otherwise False. 
+ """ return self.base_model is not None and ( isinstance(self.base_model._fields.get(field), (EmbeddedDocumentField, TopLevelDocumentMetaclass)) # noqa ) diff --git a/aggify/exceptions.py b/aggify/exceptions.py index 089e13b..e8edd06 100644 --- a/aggify/exceptions.py +++ b/aggify/exceptions.py @@ -33,7 +33,6 @@ def __init__(self, stage): class AggifyValueError(AggifyBaseException): - def __init__(self, expected_list: list[Type], result: Type): self.message = ( f"Input is not correctly passed, expected either of {[expected for expected in expected_list]}" @@ -69,10 +68,7 @@ def __init__(self, field: str): class InvalidArgument(AggifyBaseException): - def __init__(self, expected_list: list): - self.message = ( - f"Input is not correctly passed, expected {[expected for expected in expected_list]}" - ) + self.message = f"Input is not correctly passed, expected {[expected for expected in expected_list]}" self.expecteds = expected_list - super().__init__(self.message) \ No newline at end of file + super().__init__(self.message) diff --git a/aggify/utilty.py b/aggify/utilty.py index 194a230..bb6e7a3 100644 --- a/aggify/utilty.py +++ b/aggify/utilty.py @@ -79,7 +79,7 @@ def replace_values_recursive(obj, replacements): def convert_match_query( - d: dict, + d: dict, ) -> dict[Any, list[str | Any] | dict] | list[dict] | dict: """ Recursively transform a dictionary to modify the structure of '$eq' and '$ne' operators. @@ -117,15 +117,15 @@ def convert_match_query( def check_field_exists(model: Type[Document], field: str) -> None: """ - Check if a field exists in the given model. + Check if a field exists in the given model. - Args: - model (Document): The model to check for the field. - field (str): The name of the field to check. + Args: + model (Document): The model to check for the field. + field (str): The name of the field to check. - Raises: - AlreadyExistsField: If the field already exists in the model. - """ + Raises: + AlreadyExistsField: If the field already exists in the model. + """ if model._fields.get(field): # noqa raise AlreadyExistsField(field=field)