Skip to content

Commit 6600a46

Browse files
Merge pull request #19 from mohamadkhalaj/main
Fix issue #15
2 parents 1c8f7e7 + 1837ce0 commit 6600a46

File tree

1 file changed

+19
-10
lines changed

1 file changed

+19
-10
lines changed

aggify/aggify.py

+19-10
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from mongoengine import Document, EmbeddedDocument, fields
55

6-
from aggify.compiler import F, Match, Q # noqa keep
6+
from aggify.compiler import F, Match, Q, Operators # noqa keep
77
from aggify.exceptions import AggifyValueError, AnnotationError, InvalidField, InvalidEmbeddedField, OutStageError
88
from aggify.types import QueryParams
99
from aggify.utilty import (
@@ -164,19 +164,28 @@ def __to_aggregate(self, query: dict[str, Any]) -> None:
164164
if key in skip_list:
165165
continue
166166

167+
# Split the key to access the field information.
167168
split_query = key.split("__")
168-
join_field = self.base_model._fields.get(split_query[0]) # type: ignore
169-
if not join_field:
170-
raise ValueError(f"Invalid field: {split_query[0]}")
171-
# This is a nested query.
172-
if "document_type_obj" not in join_field.__dict__ or issubclass(
173-
join_field.document_type, EmbeddedDocument
169+
170+
# Retrieve the field definition from the model.
171+
join_field = self.get_model_field(self.base_model, split_query[0]) # type: ignore # noqa
172+
173+
# Check conditions for creating a 'match' pipeline stage.
174+
if (
175+
"document_type_obj" not in join_field.__dict__
176+
or issubclass(join_field.document_type, EmbeddedDocument)
177+
or len(split_query) == 1
178+
or (len(split_query) == 2 and split_query[1] in Operators.ALL_OPERATORS)
174179
):
180+
# Create a 'match' pipeline stage.
175181
match = self.__match({key: value})
176-
if (match.get("$match")) != {}:
182+
183+
# Check if the 'match' stage is not empty before adding it to the pipelines.
184+
if match.get("$match"):
177185
self.pipelines.append(match)
186+
178187
else:
179-
from_collection = join_field.document_type._meta["collection"] # noqa
188+
from_collection = join_field.document_type # noqa
180189
local_field = join_field.db_field
181190
as_name = join_field.name
182191
matches = []
@@ -191,7 +200,7 @@ def __to_aggregate(self, query: dict[str, Any]) -> None:
191200
self.pipelines.extend(
192201
[
193202
self.__lookup(
194-
from_collection=from_collection,
203+
from_collection=from_collection._meta["collection"], # noqa
195204
local_field=local_field,
196205
as_name=as_name,
197206
),

0 commit comments

Comments
 (0)