-
-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix issue #20 #23
Fix issue #20 #23
Changes from 1 commit
ad980fb
4d8c5dc
c1e1166
1fadfa3
467ac8d
39bdb40
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,23 @@ | ||
import functools | ||
from typing import Any, Literal, Dict, Type | ||
from typing import Any, Dict, Type | ||
|
||
from mongoengine import Document, EmbeddedDocument, fields | ||
from mongoengine.base import TopLevelDocumentMetaclass | ||
|
||
from aggify.compiler import F, Match, Q, Operators # noqa keep | ||
from aggify.exceptions import ( | ||
AggifyValueError, | ||
AnnotationError, | ||
InvalidField, | ||
InvalidEmbeddedField, | ||
OutStageError, | ||
OutStageError, InvalidArgument, | ||
) | ||
from aggify.types import QueryParams | ||
from aggify.utilty import ( | ||
to_mongo_positive_index, | ||
check_fields_exist, | ||
replace_values_recursive, | ||
convert_match_query, | ||
convert_match_query, check_field_exists, get_db_field, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. format code |
||
) | ||
|
||
|
||
|
@@ -174,14 +175,14 @@ def __to_aggregate(self, query: dict[str, Any]) -> None: | |
split_query = key.split("__") | ||
|
||
# Retrieve the field definition from the model. | ||
join_field = self.get_model_field(self.base_model, split_query[0]) # type: ignore # noqa | ||
|
||
join_field = self.get_model_field(self.base_model, split_query[0]) # type: ignore | ||
# Check conditions for creating a 'match' pipeline stage. | ||
if ( | ||
"document_type_obj" not in join_field.__dict__ | ||
or issubclass(join_field.document_type, EmbeddedDocument) | ||
or len(split_query) == 1 | ||
or (len(split_query) == 2 and split_query[1] in Operators.ALL_OPERATORS) | ||
isinstance(join_field, TopLevelDocumentMetaclass) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add comment and description about what do this code |
||
or "document_type_obj" not in join_field.__dict__ | ||
or issubclass(join_field.document_type, EmbeddedDocument) | ||
or len(split_query) == 1 | ||
or (len(split_query) == 2 and split_query[1] in Operators.ALL_OPERATORS) | ||
): | ||
# Create a 'match' pipeline stage. | ||
match = self.__match({key: value}) | ||
|
@@ -191,7 +192,7 @@ def __to_aggregate(self, query: dict[str, Any]) -> None: | |
self.pipelines.append(match) | ||
|
||
else: | ||
from_collection = join_field.document_type # noqa | ||
from_collection = join_field.document_type | ||
local_field = join_field.db_field | ||
as_name = join_field.name | ||
matches = [] | ||
|
@@ -210,7 +211,7 @@ def __to_aggregate(self, query: dict[str, Any]) -> None: | |
as_name=as_name, | ||
) | ||
) | ||
self.unwind(as_name) | ||
self.unwind(as_name, preserve=True) | ||
self.pipelines.extend([{"$match": match} for match in matches]) | ||
|
||
@last_out_stage_check | ||
|
@@ -228,7 +229,7 @@ def __getitem__(self, index: slice | int) -> "Aggify": | |
|
||
@last_out_stage_check | ||
def unwind( | ||
self, path: str, include_index_array: str | None = None, preserve: bool = False | ||
self, path: str, include_index_array: str | None = None, preserve: bool = False | ||
) -> "Aggify": | ||
"""Generates a MongoDB unwind pipeline stage. | ||
|
||
|
@@ -286,7 +287,7 @@ def aggregate(self): | |
return self.base_model.objects.aggregate(*self.pipelines) # type: ignore | ||
|
||
def annotate( | ||
self, annotate_name: str, accumulator: str, f: str | dict | F | int | ||
self, annotate_name: str, accumulator: str, f: str | dict | F | int | ||
) -> "Aggify": | ||
""" | ||
Annotate a MongoDB aggregation pipeline with a new field. | ||
|
@@ -356,7 +357,7 @@ def annotate( | |
else: | ||
if isinstance(f, str): | ||
try: | ||
self.get_model_field(self.base_model, f) # noqa | ||
self.get_model_field(self.base_model, f) | ||
value = f"${f}" | ||
except InvalidField: | ||
value = f | ||
|
@@ -382,7 +383,7 @@ def __match(self, matches: dict[str, Any]): | |
|
||
@staticmethod | ||
def __lookup( | ||
from_collection: str, local_field: str, as_name: str, foreign_field: str = "_id" | ||
from_collection: str, local_field: str, as_name: str, foreign_field: str = "_id" | ||
) -> dict[str, dict[str, str]]: | ||
""" | ||
Generates a MongoDB lookup pipeline stage. | ||
|
@@ -429,66 +430,85 @@ def __combine_sequential_matches(self) -> list[dict[str, dict | Any]]: | |
|
||
@last_out_stage_check | ||
def lookup( | ||
self, from_collection: Document, let: list[str], query: list[Q], as_name: str | ||
self, from_collection: Document, as_name: str, query: list[Q] | Q | None = None, | ||
let: list[str] | None = None, local_field: str | None = None, foreign_field: str | None = None | ||
) -> "Aggify": | ||
""" | ||
Generates a MongoDB lookup pipeline stage. | ||
|
||
Args: | ||
from_collection (Document): The name of the collection to lookup. | ||
let (list): The local field(s) to join on. | ||
query (list[Q]): List of desired queries with Q function. | ||
from_collection (Document): The document representing the collection to perform the lookup on. | ||
as_name (str): The name of the new field to create. | ||
query (list[Q] | Q | None, optional): List of desired queries with Q function or a single query. | ||
let (list[str] | None, optional): The local field(s) to join on. If provided, localField and foreignField are not used. | ||
local_field (str | None, optional): The local field to join on when let is not provided. | ||
foreign_field (str | None, optional): The foreign field to join on when let is not provided. | ||
|
||
Returns: | ||
Aggify: A MongoDB lookup pipeline stage. | ||
Aggify: An instance of the Aggify class representing a MongoDB lookup pipeline stage. | ||
""" | ||
check_fields_exist(self.base_model, let) # noqa | ||
|
||
let_dict = { | ||
field: f"${self.base_model._fields[field].db_field}" | ||
for field in let # noqa | ||
} | ||
from_collection = from_collection._meta.get("collection") # noqa | ||
|
||
lookup_stages = [] | ||
|
||
for q in query: | ||
# Construct the match stage for each query | ||
if isinstance(q, Q): | ||
replaced_values = replace_values_recursive( | ||
convert_match_query(dict(q)), # noqa | ||
{field: f"$${field}" for field in let}, | ||
) | ||
match_stage = {"$match": {"$expr": replaced_values.get("$match")}} | ||
lookup_stages.append(match_stage) | ||
elif isinstance(q, Aggify): | ||
lookup_stages.extend( | ||
replace_values_recursive( | ||
convert_match_query(q.pipelines), # noqa | ||
check_field_exists(self.base_model, as_name) | ||
from_collection_name = from_collection._meta.get("collection") # noqa | ||
|
||
if not let and not (local_field and foreign_field): | ||
raise InvalidArgument(expected_list=[['local_field', 'foreign_field'], 'let']) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. replace ' with " |
||
elif not let: | ||
if not (local_field and foreign_field): | ||
raise InvalidArgument(expected_list=['local_field', 'foreign_field']) | ||
lookup_stage = { | ||
"$lookup": { | ||
"from": from_collection_name, | ||
'localField': get_db_field(self.base_model, local_field), # noqa | ||
'foreignField': get_db_field(from_collection, foreign_field), # noqa | ||
"as": as_name, | ||
} | ||
} | ||
else: | ||
if not query: | ||
raise InvalidArgument(expected_list=['query']) | ||
check_fields_exist(self.base_model, let) # noqa | ||
let_dict = { | ||
field: f"${get_db_field(self.base_model, field)}" | ||
for field in let # noqa | ||
} | ||
for q in query: | ||
# Construct the match stage for each query | ||
if isinstance(q, Q): | ||
replaced_values = replace_values_recursive( | ||
convert_match_query(dict(q)), | ||
{field: f"$${field}" for field in let}, | ||
) | ||
) | ||
match_stage = {"$match": {"$expr": replaced_values.get("$match")}} | ||
lookup_stages.append(match_stage) | ||
elif isinstance(q, Aggify): | ||
lookup_stages.extend( | ||
replace_values_recursive( | ||
convert_match_query(q.pipelines), # noqa | ||
{field: f"$${field}" for field in let}, | ||
) | ||
) | ||
|
||
# Append the lookup stage with multiple match stages to the pipeline | ||
lookup_stage = { | ||
"$lookup": { | ||
"from": from_collection, | ||
"let": let_dict, | ||
"pipeline": lookup_stages, # List of match stages | ||
"as": as_name, | ||
# Append the lookup stage with multiple match stages to the pipeline | ||
lookup_stage = { | ||
"$lookup": { | ||
"from": from_collection_name, | ||
"let": let_dict, | ||
"pipeline": lookup_stages, # List of match stages | ||
"as": as_name, | ||
} | ||
} | ||
} | ||
|
||
self.pipelines.append(lookup_stage) | ||
|
||
# Add this new field to base model fields, which we can use it in the next stages. | ||
self.base_model._fields[as_name] = fields.StringField() # noqa | ||
self.base_model._fields[as_name] = from_collection # noqa | ||
|
||
return self | ||
|
||
@staticmethod | ||
def get_model_field(model: Document, field: str) -> fields: | ||
def get_model_field(model: Type[Document], field: str) -> fields: | ||
""" | ||
Get the field definition of a specified field in a MongoDB model. | ||
|
||
|
@@ -520,18 +540,18 @@ def _replace_base(self, embedded_field) -> str: | |
Raises: | ||
InvalidEmbeddedField: If the specified embedded field is not found or is not of the correct type. | ||
""" | ||
model_field = self.get_model_field(self.base_model, embedded_field) # noqa | ||
model_field = self.get_model_field(self.base_model, embedded_field) | ||
|
||
if not hasattr(model_field, "document_type") or not issubclass( | ||
model_field.document_type, EmbeddedDocument | ||
model_field.document_type, EmbeddedDocument | ||
): | ||
raise InvalidEmbeddedField(field=embedded_field) | ||
|
||
return f"${model_field.db_field}" | ||
|
||
@last_out_stage_check | ||
def replace_root( | ||
self, *, embedded_field: str, merge: dict | None = None | ||
self, *, embedded_field: str, merge: dict | None = None | ||
) -> "Aggify": | ||
""" | ||
Replace the root document in the aggregation pipeline with a specified embedded field or a merged result. | ||
|
@@ -559,7 +579,7 @@ def replace_root( | |
|
||
@last_out_stage_check | ||
def replace_with( | ||
self, *, embedded_field: str, merge: dict | None = None | ||
self, *, embedded_field: str, merge: dict | None = None | ||
) -> "Aggify": | ||
""" | ||
Replace the root document in the aggregation pipeline with a specified embedded field or a merged result. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,10 @@ | ||
from typing import Any, Type | ||
|
||
from mongoengine import Document, EmbeddedDocumentField | ||
from mongoengine.base import TopLevelDocumentMetaclass | ||
|
||
from aggify.exceptions import InvalidOperator | ||
from aggify.utilty import get_db_field | ||
|
||
|
||
class Operators: | ||
|
@@ -232,15 +234,15 @@ def validate_operator(key: str): | |
raise InvalidOperator(operator) | ||
|
||
def is_base_model_field(self, field) -> bool: | ||
return self.base_model is not None and isinstance( | ||
self.base_model._fields.get(field), # type: ignore # noqa | ||
EmbeddedDocumentField, | ||
return self.base_model is not None and ( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add doc string |
||
isinstance(self.base_model._fields.get(field), (EmbeddedDocumentField, TopLevelDocumentMetaclass)) # noqa | ||
) | ||
|
||
def compile(self, pipelines: list) -> dict[str, dict[str, list]]: | ||
match_query = {} | ||
for key, value in self.matches.items(): | ||
if "__" not in key: | ||
key = get_db_field(self.base_model, key) | ||
match_query[key] = value | ||
continue | ||
|
||
|
@@ -259,7 +261,7 @@ def compile(self, pipelines: list) -> dict[str, dict[str, list]]: | |
|
||
if operator not in Operators.ALL_OPERATORS: | ||
raise InvalidOperator(operator) | ||
|
||
match_query = Operators(match_query).compile_match(operator, value, field) | ||
db_field = get_db_field(self.base_model, field) | ||
match_query = Operators(match_query).compile_match(operator, value, db_field) | ||
|
||
return {"$match": match_query} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
format code