1
1
import functools
2
- from typing import Any , Literal , Dict , Type
2
+ from typing import Any , Dict , Type
3
3
4
4
from mongoengine import Document , EmbeddedDocument , fields
5
+ from mongoengine .base import TopLevelDocumentMetaclass
5
6
6
7
from aggify .compiler import F , Match , Q , Operators # noqa keep
7
8
from aggify .exceptions import (
10
11
InvalidField ,
11
12
InvalidEmbeddedField ,
12
13
OutStageError ,
14
+ InvalidArgument ,
13
15
)
14
16
from aggify .types import QueryParams
15
17
from aggify .utilty import (
16
18
to_mongo_positive_index ,
17
19
check_fields_exist ,
18
20
replace_values_recursive ,
19
21
convert_match_query ,
22
+ check_field_exists ,
23
+ get_db_field ,
20
24
)
21
25
22
26
@@ -174,12 +178,17 @@ def __to_aggregate(self, query: dict[str, Any]) -> None:
174
178
split_query = key .split ("__" )
175
179
176
180
# Retrieve the field definition from the model.
177
- join_field = self .get_model_field (self .base_model , split_query [0 ]) # type: ignore # noqa
178
-
181
+ join_field = self .get_model_field (self .base_model , split_query [0 ]) # type: ignore
179
182
# Check conditions for creating a 'match' pipeline stage.
180
183
if (
181
- "document_type_obj" not in join_field .__dict__
182
- or issubclass (join_field .document_type , EmbeddedDocument )
184
+ isinstance (
185
+ join_field , TopLevelDocumentMetaclass
186
+ ) # check whether field is added by lookup stage or not
187
+ or "document_type_obj"
188
+ not in join_field .__dict__ # Check whether this field is a join field or not.
189
+ or issubclass (
190
+ join_field .document_type , EmbeddedDocument
191
+ ) # Check whether this field is embedded field or not
183
192
or len (split_query ) == 1
184
193
or (len (split_query ) == 2 and split_query [1 ] in Operators .ALL_OPERATORS )
185
194
):
@@ -191,7 +200,7 @@ def __to_aggregate(self, query: dict[str, Any]) -> None:
191
200
self .pipelines .append (match )
192
201
193
202
else :
194
- from_collection = join_field .document_type # noqa
203
+ from_collection = join_field .document_type
195
204
local_field = join_field .db_field
196
205
as_name = join_field .name
197
206
matches = []
@@ -210,7 +219,7 @@ def __to_aggregate(self, query: dict[str, Any]) -> None:
210
219
as_name = as_name ,
211
220
)
212
221
)
213
- self .unwind (as_name )
222
+ self .unwind (as_name , preserve = True )
214
223
self .pipelines .extend ([{"$match" : match } for match in matches ])
215
224
216
225
@last_out_stage_check
@@ -356,7 +365,7 @@ def annotate(
356
365
else :
357
366
if isinstance (f , str ):
358
367
try :
359
- self .get_model_field (self .base_model , f ) # noqa
368
+ self .get_model_field (self .base_model , f )
360
369
value = f"${ f } "
361
370
except InvalidField :
362
371
value = f
@@ -429,66 +438,94 @@ def __combine_sequential_matches(self) -> list[dict[str, dict | Any]]:
429
438
430
439
@last_out_stage_check
431
440
def lookup (
432
- self , from_collection : Document , let : list [str ], query : list [Q ], as_name : str
441
+ self ,
442
+ from_collection : Document ,
443
+ as_name : str ,
444
+ query : list [Q ] | Q | None = None ,
445
+ let : list [str ] | None = None ,
446
+ local_field : str | None = None ,
447
+ foreign_field : str | None = None ,
433
448
) -> "Aggify" :
434
449
"""
435
450
Generates a MongoDB lookup pipeline stage.
436
451
437
452
Args:
438
- from_collection (Document): The name of the collection to lookup.
439
- let (list): The local field(s) to join on.
440
- query (list[Q]): List of desired queries with Q function.
453
+ from_collection (Document): The document representing the collection to perform the lookup on.
441
454
as_name (str): The name of the new field to create.
455
+ query (list[Q] | Q | None, optional): List of desired queries with Q function or a single query.
456
+ let (list[str] | None, optional): The local field(s) to join on. If provided, localField and foreignField are not used.
457
+ local_field (str | None, optional): The local field to join on when let is not provided.
458
+ foreign_field (str | None, optional): The foreign field to join on when let is not provided.
442
459
443
460
Returns:
444
- Aggify: A MongoDB lookup pipeline stage.
461
+ Aggify: An instance of the Aggify class representing a MongoDB lookup pipeline stage.
445
462
"""
446
- check_fields_exist (self .base_model , let ) # noqa
447
-
448
- let_dict = {
449
- field : f"${ self .base_model ._fields [field ].db_field } "
450
- for field in let # noqa
451
- }
452
- from_collection = from_collection ._meta .get ("collection" ) # noqa
453
463
454
464
lookup_stages = []
465
+ check_field_exists (self .base_model , as_name )
466
+ from_collection_name = from_collection ._meta .get ("collection" ) # noqa
455
467
456
- for q in query :
457
- # Construct the match stage for each query
458
- if isinstance (q , Q ):
459
- replaced_values = replace_values_recursive (
460
- convert_match_query (dict (q )), # noqa
461
- {field : f"$${ field } " for field in let },
462
- )
463
- match_stage = {"$match" : {"$expr" : replaced_values .get ("$match" )}}
464
- lookup_stages .append (match_stage )
465
- elif isinstance (q , Aggify ):
466
- lookup_stages .extend (
467
- replace_values_recursive (
468
- convert_match_query (q .pipelines ), # noqa
468
+ if not let and not (local_field and foreign_field ):
469
+ raise InvalidArgument (
470
+ expected_list = [["local_field" , "foreign_field" ], "let" ]
471
+ )
472
+ elif not let :
473
+ if not (local_field and foreign_field ):
474
+ raise InvalidArgument (expected_list = ["local_field" , "foreign_field" ])
475
+ lookup_stage = {
476
+ "$lookup" : {
477
+ "from" : from_collection_name ,
478
+ "localField" : get_db_field (self .base_model , local_field ), # noqa
479
+ "foreignField" : get_db_field (
480
+ from_collection , foreign_field
481
+ ), # noqa
482
+ "as" : as_name ,
483
+ }
484
+ }
485
+ else :
486
+ if not query :
487
+ raise InvalidArgument (expected_list = ["query" ])
488
+ check_fields_exist (self .base_model , let ) # noqa
489
+ let_dict = {
490
+ field : f"${ get_db_field (self .base_model , field )} "
491
+ for field in let # noqa
492
+ }
493
+ for q in query :
494
+ # Construct the match stage for each query
495
+ if isinstance (q , Q ):
496
+ replaced_values = replace_values_recursive (
497
+ convert_match_query (dict (q )),
469
498
{field : f"$${ field } " for field in let },
470
499
)
471
- )
500
+ match_stage = {"$match" : {"$expr" : replaced_values .get ("$match" )}}
501
+ lookup_stages .append (match_stage )
502
+ elif isinstance (q , Aggify ):
503
+ lookup_stages .extend (
504
+ replace_values_recursive (
505
+ convert_match_query (q .pipelines ), # noqa
506
+ {field : f"$${ field } " for field in let },
507
+ )
508
+ )
472
509
473
- # Append the lookup stage with multiple match stages to the pipeline
474
- lookup_stage = {
475
- "$lookup" : {
476
- "from" : from_collection ,
477
- "let" : let_dict ,
478
- "pipeline" : lookup_stages , # List of match stages
479
- "as" : as_name ,
510
+ # Append the lookup stage with multiple match stages to the pipeline
511
+ lookup_stage = {
512
+ "$lookup" : {
513
+ "from" : from_collection_name ,
514
+ "let" : let_dict ,
515
+ "pipeline" : lookup_stages , # List of match stages
516
+ "as" : as_name ,
517
+ }
480
518
}
481
- }
482
519
483
520
self .pipelines .append (lookup_stage )
484
521
485
522
# Add this new field to base model fields, which we can use it in the next stages.
486
- self .base_model ._fields [as_name ] = fields . StringField () # noqa
523
+ self .base_model ._fields [as_name ] = from_collection # noqa
487
524
488
525
return self
489
526
490
527
@staticmethod
491
- def get_model_field (model : Document , field : str ) -> fields :
528
+ def get_model_field (model : Type [ Document ] , field : str ) -> fields :
492
529
"""
493
530
Get the field definition of a specified field in a MongoDB model.
494
531
@@ -520,7 +557,7 @@ def _replace_base(self, embedded_field) -> str:
520
557
Raises:
521
558
InvalidEmbeddedField: If the specified embedded field is not found or is not of the correct type.
522
559
"""
523
- model_field = self .get_model_field (self .base_model , embedded_field ) # noqa
560
+ model_field = self .get_model_field (self .base_model , embedded_field )
524
561
525
562
if not hasattr (model_field , "document_type" ) or not issubclass (
526
563
model_field .document_type , EmbeddedDocument
0 commit comments