Skip to content

Commit

Permalink
added "leaves" literal option to MaskGroup.flatten how #163
Browse files Browse the repository at this point in the history
  • Loading branch information
jacanchaplais committed Oct 11, 2023
1 parent 86ebc39 commit 5e607c2
Showing 1 changed file with 29 additions and 19 deletions.
48 changes: 29 additions & 19 deletions graphicle/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,26 +914,31 @@ def recursive_drop(
return mask_group

def flatten(
self, how: ty.Literal["rise", "agg"] = "rise"
self, how: ty.Literal["rise", "agg", "leaves"] = "rise"
) -> "MaskGroup[MaskArray]":
"""Removes nesting such that the ``MaskGroup`` contains only
``MaskArray`` instances, and no other ``MaskGroup``.
.. versionadded:: 0.1.11
.. versionchanged:: 0.2.6
Added ``'how'`` parameter.
Added ``how`` parameter.
.. versionchanged:: 0.3.7
Added ``leaves`` option for ``how`` parameter.
Parameters
----------
how : {'rise', 'agg'}
Method used to convert into flat ``MaskGroup``. ``'rise'``
how : {'rise', 'agg', 'leaves'}
Method used to convert into flat ``MaskGroup``. ``rise``
recurses through nested levels, raising all contained
``MaskArray`` instances to the top level. ``'agg'`` loops
over the top level of ``MaskBase`` objects, leaving
top-level ``MaskArray`` objects as-is, but calling the
aggregation operation over any ``MaskGroup``. Default is
``'rise'``.
``MaskArray`` instances to the top level.
``agg`` loops over the top level of ``MaskBase`` objects,
leaving top-level ``MaskArray`` objects as-is, but calling
the aggregation operation over any ``MaskGroup``.
``leaves`` brings the innermosted nested ``MaskArray``
instances to the top level, discarding the rest.
Default is``'rise'``.
Returns
-------
Expand All @@ -942,26 +947,31 @@ def flatten(
at the top level.
"""

def leaves(
if how == "leaves":
from graphicle.select import leaf_masks

return leaf_masks(self)
if how == "agg":
return self.__class__(
cl.OrderedDict(
zip(self.keys(), map(op.attrgetter("data"), self.values()))
),
"or",
)

def visit(
mask_group: "MaskGroup",
) -> ty.Iterator[ty.Tuple[str, base.MaskLike]]:
for key, val in mask_group.items():
if key == "latent":
continue
if isinstance(val, type(self)):
yield key, val.data
yield from leaves(val)
yield from visit(val)
else:
yield key, val

if how == "rise":
return self.__class__(cl.OrderedDict(leaves(self)), "or") # type: ignore
return self.__class__(
cl.OrderedDict(
zip(self.keys(), map(op.attrgetter("data"), self.values()))
),
"or",
)
return self.__class__(cl.OrderedDict(visit(self)), self.agg_op)

def serialize(self) -> ty.Dict[str, ty.Any]:
"""Returns serialized data as a dictionary.
Expand Down

0 comments on commit 5e607c2

Please sign in to comment.