Skip to content

Commit 0e94799

Browse files
authored
Merge pull request #23 from mohamadkhalaj/main
Fix issue #20
2 parents d6e35c9 + 39bdb40 commit 0e94799

File tree

5 files changed

+209
-109
lines changed

5 files changed

+209
-109
lines changed

aggify/aggify.py

+82-45
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import functools
2-
from typing import Any, Literal, Dict, Type
2+
from typing import Any, Dict, Type
33

44
from mongoengine import Document, EmbeddedDocument, fields
5+
from mongoengine.base import TopLevelDocumentMetaclass
56

67
from aggify.compiler import F, Match, Q, Operators # noqa keep
78
from aggify.exceptions import (
@@ -10,13 +11,16 @@
1011
InvalidField,
1112
InvalidEmbeddedField,
1213
OutStageError,
14+
InvalidArgument,
1315
)
1416
from aggify.types import QueryParams
1517
from aggify.utilty import (
1618
to_mongo_positive_index,
1719
check_fields_exist,
1820
replace_values_recursive,
1921
convert_match_query,
22+
check_field_exists,
23+
get_db_field,
2024
)
2125

2226

@@ -174,12 +178,17 @@ def __to_aggregate(self, query: dict[str, Any]) -> None:
174178
split_query = key.split("__")
175179

176180
# Retrieve the field definition from the model.
177-
join_field = self.get_model_field(self.base_model, split_query[0]) # type: ignore # noqa
178-
181+
join_field = self.get_model_field(self.base_model, split_query[0]) # type: ignore
179182
# Check conditions for creating a 'match' pipeline stage.
180183
if (
181-
"document_type_obj" not in join_field.__dict__
182-
or issubclass(join_field.document_type, EmbeddedDocument)
184+
isinstance(
185+
join_field, TopLevelDocumentMetaclass
186+
) # check whether field is added by lookup stage or not
187+
or "document_type_obj"
188+
not in join_field.__dict__ # Check whether this field is a join field or not.
189+
or issubclass(
190+
join_field.document_type, EmbeddedDocument
191+
) # Check whether this field is embedded field or not
183192
or len(split_query) == 1
184193
or (len(split_query) == 2 and split_query[1] in Operators.ALL_OPERATORS)
185194
):
@@ -191,7 +200,7 @@ def __to_aggregate(self, query: dict[str, Any]) -> None:
191200
self.pipelines.append(match)
192201

193202
else:
194-
from_collection = join_field.document_type # noqa
203+
from_collection = join_field.document_type
195204
local_field = join_field.db_field
196205
as_name = join_field.name
197206
matches = []
@@ -210,7 +219,7 @@ def __to_aggregate(self, query: dict[str, Any]) -> None:
210219
as_name=as_name,
211220
)
212221
)
213-
self.unwind(as_name)
222+
self.unwind(as_name, preserve=True)
214223
self.pipelines.extend([{"$match": match} for match in matches])
215224

216225
@last_out_stage_check
@@ -356,7 +365,7 @@ def annotate(
356365
else:
357366
if isinstance(f, str):
358367
try:
359-
self.get_model_field(self.base_model, f) # noqa
368+
self.get_model_field(self.base_model, f)
360369
value = f"${f}"
361370
except InvalidField:
362371
value = f
@@ -429,66 +438,94 @@ def __combine_sequential_matches(self) -> list[dict[str, dict | Any]]:
429438

430439
@last_out_stage_check
def lookup(
    self,
    from_collection: Document,
    as_name: str,
    query: list[Q] | Q | None = None,
    let: list[str] | None = None,
    local_field: str | None = None,
    foreign_field: str | None = None,
) -> "Aggify":
    """
    Generates a MongoDB lookup pipeline stage.

    Two forms are supported:

    * Equality join: pass ``local_field`` and ``foreign_field`` (and no ``let``).
      The stage is emitted with ``localField`` / ``foreignField``.
    * Pipeline join: pass ``let`` together with ``query``. The stage is emitted
      with ``let`` / ``pipeline``; ``local_field`` and ``foreign_field`` are
      ignored in this form.

    Args:
        from_collection (Document): The document representing the collection to perform the lookup on.
        as_name (str): The name of the new field to create.
        query (list[Q] | Q | None, optional): List of desired queries with Q function or a single query.
            Aggify instances are also accepted as list items; their pipelines are embedded.
        let (list[str] | None, optional): The local field(s) to join on. If provided, localField and foreignField are not used.
        local_field (str | None, optional): The local field to join on when let is not provided.
        foreign_field (str | None, optional): The foreign field to join on when let is not provided.

    Returns:
        Aggify: This instance, with the lookup stage appended to its pipelines.

    Raises:
        AlreadyExistsField: If ``as_name`` is already a field of the base model
            (raised by ``check_field_exists``).
        InvalidArgument: If neither ``let`` nor both of ``local_field`` and
            ``foreign_field`` are given, or if ``let`` is given without ``query``.
    """
    check_field_exists(self.base_model, as_name)
    from_collection_name = from_collection._meta.get("collection")  # noqa

    if not let:
        # Without `let`, the equality form is the only option, so both of its
        # fields are mandatory. One check covers every invalid combination.
        if not (local_field and foreign_field):
            raise InvalidArgument(
                expected_list=[["local_field", "foreign_field"], "let"]
            )
        lookup_stage = {
            "$lookup": {
                "from": from_collection_name,
                "localField": get_db_field(self.base_model, local_field),  # noqa
                "foreignField": get_db_field(from_collection, foreign_field),  # noqa
                "as": as_name,
            }
        }
    else:
        # The signature allows a single query object; wrap it so the loop
        # below sees the object itself rather than iterating over it.
        if isinstance(query, (Q, Aggify)):
            query = [query]
        if not query:
            raise InvalidArgument(expected_list=["query"])

        check_fields_exist(self.base_model, let)  # noqa
        let_dict = {
            field: f"${get_db_field(self.base_model, field)}"
            for field in let  # noqa
        }

        lookup_stages = []
        for q in query:
            # Construct the match stage for each query
            if isinstance(q, Q):
                replaced_values = replace_values_recursive(
                    convert_match_query(dict(q)),
                    {field: f"$${field}" for field in let},
                )
                match_stage = {"$match": {"$expr": replaced_values.get("$match")}}
                lookup_stages.append(match_stage)
            elif isinstance(q, Aggify):
                lookup_stages.extend(
                    replace_values_recursive(
                        convert_match_query(q.pipelines),  # noqa
                        {field: f"$${field}" for field in let},
                    )
                )

        # Build the lookup stage that carries the collected sub-pipeline.
        lookup_stage = {
            "$lookup": {
                "from": from_collection_name,
                "let": let_dict,
                "pipeline": lookup_stages,  # List of match stages
                "as": as_name,
            }
        }

    self.pipelines.append(lookup_stage)

    # Register the joined document class under `as_name` on the base model so
    # that later stages can refer to the new field.
    self.base_model._fields[as_name] = from_collection  # noqa

    return self
489526

490527
@staticmethod
491-
def get_model_field(model: Document, field: str) -> fields:
528+
def get_model_field(model: Type[Document], field: str) -> fields:
492529
"""
493530
Get the field definition of a specified field in a MongoDB model.
494531
@@ -520,7 +557,7 @@ def _replace_base(self, embedded_field) -> str:
520557
Raises:
521558
InvalidEmbeddedField: If the specified embedded field is not found or is not of the correct type.
522559
"""
523-
model_field = self.get_model_field(self.base_model, embedded_field) # noqa
560+
model_field = self.get_model_field(self.base_model, embedded_field)
524561

525562
if not hasattr(model_field, "document_type") or not issubclass(
526563
model_field.document_type, EmbeddedDocument

aggify/compiler.py

+20-6
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from typing import Any, Type
22

33
from mongoengine import Document, EmbeddedDocumentField
4+
from mongoengine.base import TopLevelDocumentMetaclass
45

56
from aggify.exceptions import InvalidOperator
7+
from aggify.utilty import get_db_field
68

79

810
class Operators:
@@ -232,15 +234,27 @@ def validate_operator(key: str):
232234
raise InvalidOperator(operator)
233235

234236
def is_base_model_field(self, field) -> bool:
    """
    Report whether a base-model field is an embedded or lookup-added field.

    EmbeddedDocumentField: Field which is embedded.
    TopLevelDocumentMetaclass: Field which is added by lookup stage.

    Args:
        field (str): The name of the field to check.

    Returns:
        bool: False when there is no base model; otherwise True only if the
            entry stored for `field` is an EmbeddedDocumentField or a
            TopLevelDocumentMetaclass instance.
    """
    if self.base_model is None:
        return False

    # A missing entry yields None, which fails the isinstance check below.
    field_definition = self.base_model._fields.get(field)  # type: ignore # noqa
    special_field_types = (EmbeddedDocumentField, TopLevelDocumentMetaclass)
    return isinstance(field_definition, special_field_types)
239252

240253
def compile(self, pipelines: list) -> dict[str, dict[str, list]]:
241254
match_query = {}
242255
for key, value in self.matches.items():
243256
if "__" not in key:
257+
key = get_db_field(self.base_model, key)
244258
match_query[key] = value
245259
continue
246260

@@ -249,7 +263,7 @@ def compile(self, pipelines: list) -> dict[str, dict[str, list]]:
249263
raise InvalidOperator(key)
250264

251265
field, operator, *_ = key.split("__")
252-
if self.is_base_model_field(field):
266+
if self.is_base_model_field(field) and operator not in Operators.ALL_OPERATORS:
253267
pipelines.append(
254268
Match({key.replace("__", ".", 1): value}, self.base_model).compile(
255269
[]
@@ -259,7 +273,7 @@ def compile(self, pipelines: list) -> dict[str, dict[str, list]]:
259273

260274
if operator not in Operators.ALL_OPERATORS:
261275
raise InvalidOperator(operator)
262-
263-
match_query = Operators(match_query).compile_match(operator, value, field)
276+
db_field = get_db_field(self.base_model, field)
277+
match_query = Operators(match_query).compile_match(operator, value, db_field)
264278

265279
return {"$match": match_query}

aggify/exceptions.py

+13
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,16 @@ class InvalidEmbeddedField(AggifyBaseException):
5959
def __init__(self, field: str):
6060
self.message = f"Field {field} is not embedded."
6161
super().__init__(self.message)
62+
63+
64+
class AlreadyExistsField(AggifyBaseException):
    """Raised when a field name is already present on a model."""

    def __init__(self, field: str):
        # Build the text once, keep it on the instance, and hand the same
        # text to the base exception.
        text = f"Field {field} already exists."
        self.message = text
        super().__init__(text)
68+
69+
70+
class InvalidArgument(AggifyBaseException):
    """Raised when a call does not supply the arguments it requires.

    The caller passes the names of the expected arguments; they are shown in
    the message and kept, unchanged, on the ``expecteds`` attribute.
    """

    def __init__(self, expected_list: list):
        # list(...) renders the same text as the former identity comprehension.
        self.message = f"Input is not correctly passed, expected {list(expected_list)}"
        self.expecteds = expected_list
        super().__init__(self.message)

aggify/utilty.py

+35-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from typing import Any
1+
from typing import Any, Type
22

33
from mongoengine import Document
44

5-
from aggify.exceptions import MongoIndexError, InvalidField
5+
from aggify.exceptions import MongoIndexError, InvalidField, AlreadyExistsField
66

77

88
def int_to_slice(final_index: int) -> slice:
@@ -113,3 +113,36 @@ def convert_match_query(
113113
return [convert_match_query(item) for item in d]
114114
else:
115115
return d
116+
117+
118+
def check_field_exists(model: Type[Document], field: str) -> None:
    """
    Ensure that a field name is not already taken on the given model.

    Args:
        model (Document): The model whose `_fields` mapping is inspected.
        field (str): The name of the field to look for.

    Raises:
        AlreadyExistsField: If a truthy entry is stored for `field` in the model.
    """
    existing_entry = model._fields.get(field)  # noqa
    if not existing_entry:
        return None
    raise AlreadyExistsField(field=field)
131+
132+
133+
def get_db_field(model: "Type[Document]", field: str) -> str:
    """
    Get the database field name for a given field in the model.

    Args:
        model (Document): The model containing the field.
        field (str): The name of the field.

    Returns:
        str: The database field name if available, otherwise the original field name.
            The original name is returned when the model has no usable `_fields`
            mapping, when the field is unknown, or when the stored entry does not
            carry a string `db_field` (for example an entry that is not a regular
            field definition).
    """
    # Only the attribute chain can raise; keep it alone inside the try.
    try:
        db_field = model._fields.get(field).db_field  # noqa
    except AttributeError:
        return field

    # Honour the declared return type: anything other than a string (None
    # included) means there is no usable database name for this field.
    return db_field if isinstance(db_field, str) else field

0 commit comments

Comments
 (0)