Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue #20 #23

Merged
merged 6 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 76 additions & 56 deletions aggify/aggify.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
import functools
from typing import Any, Literal, Dict, Type
from typing import Any, Dict, Type

from mongoengine import Document, EmbeddedDocument, fields
from mongoengine.base import TopLevelDocumentMetaclass

from aggify.compiler import F, Match, Q, Operators # noqa keep
from aggify.exceptions import (
AggifyValueError,
AnnotationError,
InvalidField,
InvalidEmbeddedField,
OutStageError,
OutStageError, InvalidArgument,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

format code

)
from aggify.types import QueryParams
from aggify.utilty import (
to_mongo_positive_index,
check_fields_exist,
replace_values_recursive,
convert_match_query,
convert_match_query, check_field_exists, get_db_field,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

format code

)


Expand Down Expand Up @@ -174,14 +175,14 @@ def __to_aggregate(self, query: dict[str, Any]) -> None:
split_query = key.split("__")

# Retrieve the field definition from the model.
join_field = self.get_model_field(self.base_model, split_query[0]) # type: ignore # noqa

join_field = self.get_model_field(self.base_model, split_query[0]) # type: ignore
# Check conditions for creating a 'match' pipeline stage.
if (
"document_type_obj" not in join_field.__dict__
or issubclass(join_field.document_type, EmbeddedDocument)
or len(split_query) == 1
or (len(split_query) == 2 and split_query[1] in Operators.ALL_OPERATORS)
isinstance(join_field, TopLevelDocumentMetaclass)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add comment and description about what do this code

or "document_type_obj" not in join_field.__dict__
or issubclass(join_field.document_type, EmbeddedDocument)
or len(split_query) == 1
or (len(split_query) == 2 and split_query[1] in Operators.ALL_OPERATORS)
):
# Create a 'match' pipeline stage.
match = self.__match({key: value})
Expand All @@ -191,7 +192,7 @@ def __to_aggregate(self, query: dict[str, Any]) -> None:
self.pipelines.append(match)

else:
from_collection = join_field.document_type # noqa
from_collection = join_field.document_type
local_field = join_field.db_field
as_name = join_field.name
matches = []
Expand All @@ -210,7 +211,7 @@ def __to_aggregate(self, query: dict[str, Any]) -> None:
as_name=as_name,
)
)
self.unwind(as_name)
self.unwind(as_name, preserve=True)
self.pipelines.extend([{"$match": match} for match in matches])

@last_out_stage_check
Expand All @@ -228,7 +229,7 @@ def __getitem__(self, index: slice | int) -> "Aggify":

@last_out_stage_check
def unwind(
self, path: str, include_index_array: str | None = None, preserve: bool = False
self, path: str, include_index_array: str | None = None, preserve: bool = False
) -> "Aggify":
"""Generates a MongoDB unwind pipeline stage.

Expand Down Expand Up @@ -286,7 +287,7 @@ def aggregate(self):
return self.base_model.objects.aggregate(*self.pipelines) # type: ignore

def annotate(
self, annotate_name: str, accumulator: str, f: str | dict | F | int
self, annotate_name: str, accumulator: str, f: str | dict | F | int
) -> "Aggify":
"""
Annotate a MongoDB aggregation pipeline with a new field.
Expand Down Expand Up @@ -356,7 +357,7 @@ def annotate(
else:
if isinstance(f, str):
try:
self.get_model_field(self.base_model, f) # noqa
self.get_model_field(self.base_model, f)
value = f"${f}"
except InvalidField:
value = f
Expand All @@ -382,7 +383,7 @@ def __match(self, matches: dict[str, Any]):

@staticmethod
def __lookup(
from_collection: str, local_field: str, as_name: str, foreign_field: str = "_id"
from_collection: str, local_field: str, as_name: str, foreign_field: str = "_id"
) -> dict[str, dict[str, str]]:
"""
Generates a MongoDB lookup pipeline stage.
Expand Down Expand Up @@ -429,66 +430,85 @@ def __combine_sequential_matches(self) -> list[dict[str, dict | Any]]:

@last_out_stage_check
def lookup(
self, from_collection: Document, let: list[str], query: list[Q], as_name: str
self, from_collection: Document, as_name: str, query: list[Q] | Q | None = None,
let: list[str] | None = None, local_field: str | None = None, foreign_field: str | None = None
) -> "Aggify":
"""
Generates a MongoDB lookup pipeline stage.

Args:
from_collection (Document): The name of the collection to lookup.
let (list): The local field(s) to join on.
query (list[Q]): List of desired queries with Q function.
from_collection (Document): The document representing the collection to perform the lookup on.
as_name (str): The name of the new field to create.
query (list[Q] | Q | None, optional): List of desired queries with Q function or a single query.
let (list[str] | None, optional): The local field(s) to join on. If provided, localField and foreignField are not used.
local_field (str | None, optional): The local field to join on when let is not provided.
foreign_field (str | None, optional): The foreign field to join on when let is not provided.

Returns:
Aggify: A MongoDB lookup pipeline stage.
Aggify: An instance of the Aggify class representing a MongoDB lookup pipeline stage.
"""
check_fields_exist(self.base_model, let) # noqa

let_dict = {
field: f"${self.base_model._fields[field].db_field}"
for field in let # noqa
}
from_collection = from_collection._meta.get("collection") # noqa

lookup_stages = []

for q in query:
# Construct the match stage for each query
if isinstance(q, Q):
replaced_values = replace_values_recursive(
convert_match_query(dict(q)), # noqa
{field: f"$${field}" for field in let},
)
match_stage = {"$match": {"$expr": replaced_values.get("$match")}}
lookup_stages.append(match_stage)
elif isinstance(q, Aggify):
lookup_stages.extend(
replace_values_recursive(
convert_match_query(q.pipelines), # noqa
check_field_exists(self.base_model, as_name)
from_collection_name = from_collection._meta.get("collection") # noqa

if not let and not (local_field and foreign_field):
raise InvalidArgument(expected_list=[['local_field', 'foreign_field'], 'let'])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

replace ' with "

elif not let:
if not (local_field and foreign_field):
raise InvalidArgument(expected_list=['local_field', 'foreign_field'])
lookup_stage = {
"$lookup": {
"from": from_collection_name,
'localField': get_db_field(self.base_model, local_field), # noqa
'foreignField': get_db_field(from_collection, foreign_field), # noqa
"as": as_name,
}
}
else:
if not query:
raise InvalidArgument(expected_list=['query'])
check_fields_exist(self.base_model, let) # noqa
let_dict = {
field: f"${get_db_field(self.base_model, field)}"
for field in let # noqa
}
for q in query:
# Construct the match stage for each query
if isinstance(q, Q):
replaced_values = replace_values_recursive(
convert_match_query(dict(q)),
{field: f"$${field}" for field in let},
)
)
match_stage = {"$match": {"$expr": replaced_values.get("$match")}}
lookup_stages.append(match_stage)
elif isinstance(q, Aggify):
lookup_stages.extend(
replace_values_recursive(
convert_match_query(q.pipelines), # noqa
{field: f"$${field}" for field in let},
)
)

# Append the lookup stage with multiple match stages to the pipeline
lookup_stage = {
"$lookup": {
"from": from_collection,
"let": let_dict,
"pipeline": lookup_stages, # List of match stages
"as": as_name,
# Append the lookup stage with multiple match stages to the pipeline
lookup_stage = {
"$lookup": {
"from": from_collection_name,
"let": let_dict,
"pipeline": lookup_stages, # List of match stages
"as": as_name,
}
}
}

self.pipelines.append(lookup_stage)

# Add this new field to base model fields, which we can use it in the next stages.
self.base_model._fields[as_name] = fields.StringField() # noqa
self.base_model._fields[as_name] = from_collection # noqa

return self

@staticmethod
def get_model_field(model: Document, field: str) -> fields:
def get_model_field(model: Type[Document], field: str) -> fields:
"""
Get the field definition of a specified field in a MongoDB model.

Expand Down Expand Up @@ -520,18 +540,18 @@ def _replace_base(self, embedded_field) -> str:
Raises:
InvalidEmbeddedField: If the specified embedded field is not found or is not of the correct type.
"""
model_field = self.get_model_field(self.base_model, embedded_field) # noqa
model_field = self.get_model_field(self.base_model, embedded_field)

if not hasattr(model_field, "document_type") or not issubclass(
model_field.document_type, EmbeddedDocument
model_field.document_type, EmbeddedDocument
):
raise InvalidEmbeddedField(field=embedded_field)

return f"${model_field.db_field}"

@last_out_stage_check
def replace_root(
self, *, embedded_field: str, merge: dict | None = None
self, *, embedded_field: str, merge: dict | None = None
) -> "Aggify":
"""
Replace the root document in the aggregation pipeline with a specified embedded field or a merged result.
Expand Down Expand Up @@ -559,7 +579,7 @@ def replace_root(

@last_out_stage_check
def replace_with(
self, *, embedded_field: str, merge: dict | None = None
self, *, embedded_field: str, merge: dict | None = None
) -> "Aggify":
"""
Replace the root document in the aggregation pipeline with a specified embedded field or a merged result.
Expand Down
12 changes: 7 additions & 5 deletions aggify/compiler.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Any, Type

from mongoengine import Document, EmbeddedDocumentField
from mongoengine.base import TopLevelDocumentMetaclass

from aggify.exceptions import InvalidOperator
from aggify.utilty import get_db_field


class Operators:
Expand Down Expand Up @@ -232,15 +234,15 @@ def validate_operator(key: str):
raise InvalidOperator(operator)

def is_base_model_field(self, field) -> bool:
return self.base_model is not None and isinstance(
self.base_model._fields.get(field), # type: ignore # noqa
EmbeddedDocumentField,
return self.base_model is not None and (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add doc string

isinstance(self.base_model._fields.get(field), (EmbeddedDocumentField, TopLevelDocumentMetaclass)) # noqa
)

def compile(self, pipelines: list) -> dict[str, dict[str, list]]:
match_query = {}
for key, value in self.matches.items():
if "__" not in key:
key = get_db_field(self.base_model, key)
match_query[key] = value
continue

Expand All @@ -259,7 +261,7 @@ def compile(self, pipelines: list) -> dict[str, dict[str, list]]:

if operator not in Operators.ALL_OPERATORS:
raise InvalidOperator(operator)

match_query = Operators(match_query).compile_match(operator, value, field)
db_field = get_db_field(self.base_model, field)
match_query = Operators(match_query).compile_match(operator, value, db_field)

return {"$match": match_query}