@@ -590,49 +590,60 @@ def _make_unwind_project_stage(only: list):
590590 }
591591
592592 @classmethod
593- def _stat_with_unwind (
593+ def _stat_with_pipeline (
594594 cls ,
595- unwind : list ,
595+ lookup : list = None ,
596+ unwind : dict = None ,
597+ add_fields : dict = None ,
596598 only : list = None ,
597599 filter : list = None ,
598600 filter_or : list = None ,
599601 sort : list = None ,
600602 page : dict = None ,
601603 target : str = None ,
602604 ):
603- if only is None :
604- raise ERROR_DB_QUERY (reason = "unwind option requires only option." )
605+ if unwind :
606+ if only is None :
607+ raise ERROR_DB_QUERY (reason = "unwind option requires only option." )
605608
606- if not isinstance (unwind , dict ):
607- raise ERROR_DB_QUERY (reason = "unwind option should be dict type." )
609+ if not isinstance (unwind , dict ):
610+ raise ERROR_DB_QUERY (reason = "unwind option should be dict type." )
608611
609- if "path" not in unwind :
610- raise ERROR_DB_QUERY (reason = "unwind option should have path key." )
612+ if "path" not in unwind :
613+ raise ERROR_DB_QUERY (reason = "unwind option should have path key." )
611614
612- unwind_path = unwind ["path" ]
613- aggregate = [{"unwind" : unwind }]
615+ aggregate = []
614616
615- # Add project stage
616- project_fields = []
617- for key in only :
618- project_fields .append (
617+ if lookup :
618+ for lu in lookup :
619+ aggregate .append ({"lookup" : lu })
620+
621+ if unwind :
622+ aggregate .append ({"unwind" : unwind })
623+
624+ if add_fields :
625+ aggregate .append ({"add_fields" : add_fields })
626+
627+ if only :
628+ project_fields = []
629+ for key in only :
630+ project_fields .append (
631+ {
632+ "key" : key ,
633+ "name" : key ,
634+ }
635+ )
636+
637+ aggregate .append (
619638 {
620- "key" : key ,
621- "name" : key ,
639+ "project" : {
640+ "exclude_keys" : True ,
641+ "only_keys" : True ,
642+ "fields" : project_fields ,
643+ }
622644 }
623645 )
624646
625- aggregate .append (
626- {
627- "project" : {
628- "exclude_keys" : True ,
629- "only_keys" : True ,
630- "fields" : project_fields ,
631- }
632- }
633- )
634-
635- # Add sort stage
636647 if sort :
637648 aggregate .append ({"sort" : sort })
638649
@@ -641,21 +652,23 @@ def _stat_with_unwind(
641652 filter = filter ,
642653 filter_or = filter_or ,
643654 page = page ,
644- tageet = target ,
655+ target = target ,
645656 allow_disk_use = True ,
646657 )
647658
648659 try :
649660 vos = []
650661 total_count = response .get ("total_count" , 0 )
651662 for result in response .get ("results" , []):
652- unwind_data = utils .get_dict_value (result , unwind_path )
653- result = utils .change_dict_value (result , unwind_path , [unwind_data ])
663+ if unwind :
664+ unwind_path = unwind ["path" ]
665+ unwind_data = utils .get_dict_value (result , unwind_path )
666+ result = utils .change_dict_value (result , unwind_path , [unwind_data ])
654667
655668 vo = cls (** result )
656669 vos .append (vo )
657670 except Exception as e :
658- raise ERROR_DB_QUERY (reason = f"Failed to convert unwind result: { e } " )
671+ raise ERROR_DB_QUERY (reason = f"Failed to convert pipeline result: { e } " )
659672
660673 return vos , total_count
661674
@@ -672,7 +685,9 @@ def query(
672685 minimal = False ,
673686 include_count = True ,
674687 count_only = False ,
688+ lookup = None ,
675689 unwind = None ,
690+ add_fields = None ,
676691 reference_filter = None ,
677692 target = None ,
678693 hint = None ,
@@ -683,9 +698,17 @@ def query(
683698 sort = sort or []
684699 page = page or {}
685700
686- if unwind :
687- return cls ._stat_with_unwind (
688- unwind , only , filter , filter_or , sort , page , target
701+ if unwind or lookup or add_fields :
702+ return cls ._stat_with_pipeline (
703+ lookup = lookup ,
704+ unwind = unwind ,
705+ add_fields = add_fields ,
706+ only = only ,
707+ filter = filter ,
708+ filter_or = filter_or ,
709+ sort = sort ,
710+ page = page ,
711+ target = target ,
689712 )
690713
691714 else :
@@ -1075,6 +1098,44 @@ def _make_match_rule(cls, options):
10751098
10761099 return {"$match" : match_options }
10771100
1101+ @classmethod
1102+ def _make_lookup_rule (cls , options ):
1103+ return {"$lookup" : options }
1104+
1105+ @classmethod
1106+ def _make_add_fields_rule (cls , options ):
1107+ add_fields_options = {}
1108+
1109+ for field , conditional in options .items ():
1110+ add_fields_options .update (
1111+ {field : cls ._process_conditional_expression (conditional )}
1112+ )
1113+
1114+ return {"$addFields" : add_fields_options }
1115+
1116+ @classmethod
1117+ def _process_conditional_expression (cls , expression ):
1118+ if isinstance (expression , dict ):
1119+ if_expression = expression ["if" ]
1120+
1121+ if isinstance (if_expression , dict ):
1122+ replaced = {}
1123+ for k , v in if_expression .items ():
1124+ new_k = k .replace ("__" , "$" )
1125+ replaced [new_k ] = v
1126+
1127+ if_expression = replaced
1128+
1129+ return {
1130+ "$cond" : {
1131+ "if" : if_expression ,
1132+ "then" : cls ._process_conditional_expression (expression ["then" ]),
1133+ "else" : cls ._process_conditional_expression (expression ["else" ]),
1134+ }
1135+ }
1136+
1137+ return expression
1138+
10781139 @classmethod
10791140 def _make_aggregate_rules (cls , aggregate ):
10801141 _aggregate_rules = []
@@ -1116,6 +1177,12 @@ def _make_aggregate_rules(cls, aggregate):
11161177 elif "match" in stage :
11171178 rule = cls ._make_match_rule (stage ["match" ])
11181179 _aggregate_rules .append (rule )
1180+ elif "lookup" in stage :
1181+ rule = cls ._make_lookup_rule (stage ["lookup" ])
1182+ _aggregate_rules .append (rule )
1183+ elif "add_fields" in stage :
1184+ rule = cls ._make_add_fields_rule (stage ["add_fields" ])
1185+ _aggregate_rules .append (rule )
11191186 else :
11201187 raise ERROR_REQUIRED_PARAMETER (
11211188 key = "aggregate.unwind or aggregate.group or "
@@ -1514,7 +1581,9 @@ def analyze(
15141581 sort = None ,
15151582 start = None ,
15161583 end = None ,
1584+ lookup = None ,
15171585 unwind = None ,
1586+ add_fields = None ,
15181587 date_field = "date" ,
15191588 date_field_format = "%Y-%m-%d" ,
15201589 reference_filter = None ,
@@ -1552,9 +1621,16 @@ def analyze(
15521621
15531622 aggregate = []
15541623
1624+ if lookup :
1625+ for lu in lookup :
1626+ aggregate .append ({"lookup" : lu })
1627+
15551628 if unwind :
15561629 aggregate .append ({"unwind" : unwind })
15571630
1631+ if add_fields :
1632+ aggregate .append ({"add_fields" : add_fields })
1633+
15581634 aggregate .append ({"group" : {"keys" : group_keys , "fields" : group_fields }})
15591635
15601636 query = {
0 commit comments