Skip to content

Commit 164e6e1

Browse files
Merge pull request #28 from mohamadkhalaj/main
Fix issue #26
2 parents bc2e29d + 5a47462 commit 164e6e1

File tree

2 files changed

+47
-3
lines changed

2 files changed

+47
-3
lines changed

aggify/aggify.py

+46-2
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,48 @@ def __init__(self, base_model: Type[Document]):
6363

6464
@last_out_stage_check
6565
def project(self, **kwargs: QueryParams) -> "Aggify":
66+
"""
67+
Adjusts the base model's fields based on the given keyword arguments.
68+
69+
Fields to be retained are set to 1 in kwargs.
70+
Fields to be deleted are set to 0 in kwargs, except for _id which is controlled by the delete_id flag.
71+
72+
Args:
73+
**kwargs: Fields to be retained or removed.
74+
For example: {"field1": 1, "field2": 0}
75+
_id field behavior: {"_id": 0} means delete _id.
76+
77+
Returns:
78+
Aggify: Returns an instance of the Aggify class for potential method chaining.
79+
"""
80+
81+
# Extract fields to keep and check if _id should be deleted
82+
to_keep_values = ["id"]
83+
delete_id = kwargs.get("_id") == 0
84+
85+
# Add missing fields to the base model
86+
for key, value in kwargs.items():
87+
if value == 1:
88+
to_keep_values.append(key)
89+
elif key not in self.base_model._fields and isinstance( # noqa
90+
kwargs[key], str
91+
): # noqa
92+
to_keep_values.append(key)
93+
self.base_model._fields[key] = fields.IntField() # noqa
94+
95+
# Remove fields from the base model, except the ones in to_keep_values and possibly _id
96+
keys_for_deletion = set(self.base_model._fields.keys()) - set( # noqa
97+
to_keep_values
98+
) # noqa
99+
if delete_id:
100+
keys_for_deletion.add("id")
101+
for key in keys_for_deletion:
102+
del self.base_model._fields[key] # noqa
103+
104+
# Append the projection stage to the pipelines
66105
self.pipelines.append({"$project": kwargs})
106+
107+
# Return the instance for method chaining
67108
return self
68109

69110
@last_out_stage_check
@@ -87,7 +128,7 @@ def raw(self, raw_query: dict) -> "Aggify":
87128
return self
88129

89130
@last_out_stage_check
90-
def add_fields(self, **fields) -> "Aggify": # noqa
131+
def add_fields(self, **_fields) -> "Aggify": # noqa
91132
"""
92133
Generates a MongoDB addFields pipeline stage.
93134
@@ -99,7 +140,8 @@ def add_fields(self, **fields) -> "Aggify": # noqa
99140
"""
100141
add_fields_stage = {"$addFields": {}}
101142

102-
for field, expression in fields.items():
143+
for field, expression in _fields.items():
144+
field = field.replace("__", ".")
103145
if isinstance(expression, str):
104146
add_fields_stage["$addFields"][field] = {"$literal": expression}
105147
elif isinstance(expression, F):
@@ -108,6 +150,8 @@ def add_fields(self, **fields) -> "Aggify": # noqa
108150
add_fields_stage["$addFields"][field] = dict(expression)
109151
else:
110152
raise AggifyValueError([str, F], type(expression))
153+
# TODO: Should be checked if new field is embedded, create embedded field.
154+
self.base_model._fields[field.replace("$", "")] = fields.IntField() # noqa
111155

112156
self.pipelines.append(add_fields_stage)
113157
return self

aggify/compiler.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def __invert__(self):
115115
class F:
116116
def __init__(self, field: str | dict[str, list]):
117117
if isinstance(field, str):
118-
self.field = f"${field}"
118+
self.field = f"${field.replace('__', '.')}"
119119
else:
120120
self.field = field
121121

0 commit comments

Comments
 (0)