Skip to content

Commit

Permalink
chore: refactor a bit
Browse files Browse the repository at this point in the history
  • Loading branch information
jbarreau committed Dec 18, 2024
1 parent 84634e8 commit 34e72f9
Showing 1 changed file with 33 additions and 28 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Union, cast
from typing import Dict, List, Optional, Union, cast

from forestadmin.agent_toolkit.utils.context import User
from forestadmin.datasource_toolkit.decorators.collection_decorator import CollectionDecorator
Expand All @@ -18,7 +18,7 @@ async def list(self, caller: User, filter_: PaginatedFilter, projection: Project
refined_filter = cast(PaginatedFilter, await self._refine_filter(caller, filter_))
ret = await self.child_collection.list(caller, refined_filter, simplified_projection)

return self._apply_joins_on_records(projection, simplified_projection, ret)
return self._apply_joins_on_simplified_records(projection, simplified_projection, ret)

async def _refine_filter(
self, caller: User, _filter: Union[Filter, PaginatedFilter, None]
Expand All @@ -29,11 +29,11 @@ async def _refine_filter(
_filter.condition_tree = _filter.condition_tree.replace(
lambda leaf: (
ConditionTreeLeaf(
self._get_fk_field_for_projection(leaf.field),
self._get_fk_field_for_many_to_one_projection(leaf.field),
leaf.operator,
leaf.value,
)
if self._is_useless_join(leaf.field.split(":")[0], _filter.condition_tree.projection)
if self._is_useless_join_for_projection(leaf.field.split(":")[0], _filter.condition_tree.projection)
else leaf
)
)
Expand All @@ -43,36 +43,25 @@ async def _refine_filter(
async def aggregate(
self, caller: User, filter_: Union[Filter, None], aggregation: Aggregation, limit: Optional[int] = None
) -> List[AggregateResult]:
replaced = {}
replaced = {} # new_name -> old_name; for a simpler reconciliation

def replacer(field_name: str) -> str:
if self._is_useless_join(field_name.split(":")[0], aggregation.projection):
new_field_name = self._get_fk_field_for_projection(field_name)
if self._is_useless_join_for_projection(field_name.split(":")[0], aggregation.projection):
new_field_name = self._get_fk_field_for_many_to_one_projection(field_name)
replaced[new_field_name] = field_name
return new_field_name
else:
return field_name
return field_name

new_aggregation = aggregation.replace_fields(replacer)

aggregate_result = await self.child_collection.aggregate(
aggregate_results = await self.child_collection.aggregate(
caller, cast(Filter, await self._refine_filter(caller, filter_)), new_aggregation, limit
)
if aggregation == new_aggregation:
return aggregate_result
return aggregate_results
return self._replace_fields_in_aggregate_group(aggregate_results, replaced)

for result in aggregate_result:
group = {}
for field, value in result["group"].items():
if field in replaced:
group[replaced[field]] = value
else:
group[field] = value
result["group"] = group

return aggregate_result

def _is_useless_join(self, relation: str, projection: Projection) -> bool:
def _is_useless_join_for_projection(self, relation: str, projection: Projection) -> bool:
relation_schema = self.schema["fields"][relation]
sub_projections = projection.relations[relation]

Expand All @@ -82,7 +71,7 @@ def _is_useless_join(self, relation: str, projection: Projection) -> bool:
and sub_projections[0] == relation_schema["foreign_key_target"]
)

def _get_fk_field_for_projection(self, projection: str) -> str:
def _get_fk_field_for_many_to_one_projection(self, projection: str) -> str:
relation_name = projection.split(":")[0]
relation_schema = cast(ManyToOne, self.schema["fields"][relation_name])

Expand All @@ -91,18 +80,18 @@ def _get_fk_field_for_projection(self, projection: str) -> str:
def _get_projection_without_useless_joins(self, projection: Projection) -> Projection:
returned_projection = Projection(*projection)
for relation, relation_projections in projection.relations.items():
if self._is_useless_join(relation, projection):
if self._is_useless_join_for_projection(relation, projection):
# remove foreign key target from projection
returned_projection.remove(f"{relation}:{relation_projections[0]}")

# add foreign keys to projection
fk_field = self._get_fk_field_for_projection(f"{relation}:{relation_projections[0]}")
fk_field = self._get_fk_field_for_many_to_one_projection(f"{relation}:{relation_projections[0]}")
if fk_field not in returned_projection:
returned_projection.append(fk_field)

return returned_projection

def _apply_joins_on_records(
def _apply_joins_on_simplified_records(
self, initial_projection: Projection, requested_projection: Projection, records: List[RecordsDataAlias]
) -> List[RecordsDataAlias]:
if requested_projection == initial_projection:
Expand All @@ -117,11 +106,27 @@ def _apply_joins_on_records(
relation_schema = self.schema["fields"][relation]

if is_many_to_one(relation_schema):
fk_value = record[self._get_fk_field_for_projection(f"{relation}:{relation_projections[0]}")]
fk_value = record[
self._get_fk_field_for_many_to_one_projection(f"{relation}:{relation_projections[0]}")
]
record[relation] = {relation_projections[0]: fk_value} if fk_value else None

# remove foreign keys
for projection in projections_to_rm:
del record[projection]

return records

def _replace_fields_in_aggregate_group(
self, aggregate_results: List[AggregateResult], field_to_replace: Dict[str, str]
) -> List[AggregateResult]:
for aggregate_result in aggregate_results:
group = {}
for field, value in aggregate_result["group"].items():
if field in field_to_replace:
group[field_to_replace[field]] = value
else:
group[field] = value
aggregate_result["group"] = group

return aggregate_results

0 comments on commit 34e72f9

Please sign in to comment.