diff --git a/graphicle/data.py b/graphicle/data.py index 1477cd4..c5f183c 100644 --- a/graphicle/data.py +++ b/graphicle/data.py @@ -914,7 +914,7 @@ def recursive_drop( return mask_group def flatten( - self, how: ty.Literal["rise", "agg"] = "rise" + self, how: ty.Literal["rise", "agg", "leaves"] = "rise" ) -> "MaskGroup[MaskArray]": """Removes nesting such that the ``MaskGroup`` contains only ``MaskArray`` instances, and no other ``MaskGroup``. @@ -922,18 +922,23 @@ def flatten( .. versionadded:: 0.1.11 .. versionchanged:: 0.2.6 - Added ``'how'`` parameter. + Added ``how`` parameter. + + .. versionchanged:: 0.3.7 + Added ``leaves`` option for ``how`` parameter. Parameters ---------- - how : {'rise', 'agg'} - Method used to convert into flat ``MaskGroup``. ``'rise'`` + how : {'rise', 'agg', 'leaves'} + Method used to convert into flat ``MaskGroup``. ``rise`` recurses through nested levels, raising all contained - ``MaskArray`` instances to the top level. ``'agg'`` loops - over the top level of ``MaskBase`` objects, leaving - top-level ``MaskArray`` objects as-is, but calling the - aggregation operation over any ``MaskGroup``. Default is - ``'rise'``. + ``MaskArray`` instances to the top level. + ``agg`` loops over the top level of ``MaskBase`` objects, + leaving top-level ``MaskArray`` objects as-is, but calling + the aggregation operation over any ``MaskGroup``. + ``leaves`` brings the innermosted nested ``MaskArray`` + instances to the top level, discarding the rest. + Default is``'rise'``. Returns ------- @@ -942,7 +947,19 @@ def flatten( at the top level. """ - def leaves( + if how == "leaves": + from graphicle.select import leaf_masks + + return leaf_masks(self) + if how == "agg": + return self.__class__( + cl.OrderedDict( + zip(self.keys(), map(op.attrgetter("data"), self.values())) + ), + "or", + ) + + def visit( mask_group: "MaskGroup", ) -> ty.Iterator[ty.Tuple[str, base.MaskLike]]: for key, val in mask_group.items(): @@ -950,18 +967,11 @@ def leaves( continue if isinstance(val, type(self)): yield key, val.data - yield from leaves(val) + yield from visit(val) else: yield key, val - if how == "rise": - return self.__class__(cl.OrderedDict(leaves(self)), "or") # type: ignore - return self.__class__( - cl.OrderedDict( - zip(self.keys(), map(op.attrgetter("data"), self.values())) - ), - "or", - ) + return self.__class__(cl.OrderedDict(visit(self)), self.agg_op) def serialize(self) -> ty.Dict[str, ty.Any]: """Returns serialized data as a dictionary.