From 86ebc3980c4b641ec5c8ddd9eb7bb2889a5064d0 Mon Sep 17 00:00:00 2001 From: Jacan Chaplais Date: Wed, 11 Oct 2023 19:34:47 +0100 Subject: [PATCH 1/3] leaf_masks now inherits agg_op from mask_tree #163 --- graphicle/select.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/graphicle/select.py b/graphicle/select.py index 22b3edc..3661c64 100644 --- a/graphicle/select.py +++ b/graphicle/select.py @@ -870,8 +870,7 @@ def _leaf_mask_iter( for name, mask in branch.items(): if exclude_latent and name == "latent": continue - # TODO: look into contravariant type for this - yield from _leaf_mask_iter(name, mask, exclude_latent) # type: ignore + yield from _leaf_mask_iter(name, mask, exclude_latent) def leaf_masks(mask_tree: gcl.MaskGroup) -> gcl.MaskGroup[gcl.MaskArray]: @@ -882,6 +881,9 @@ def leaf_masks(mask_tree: gcl.MaskGroup) -> gcl.MaskGroup[gcl.MaskArray]: .. versionadded:: 0.1.11 + .. versionchanged 0.3.7 + Output ``MaskGroup`` matches agg_op of ``mask_tree``. + Parameters ---------- mask_tree : MaskGroup @@ -893,7 +895,7 @@ def leaf_masks(mask_tree: gcl.MaskGroup) -> gcl.MaskGroup[gcl.MaskArray]: MaskGroup Flat ``MaskGroup`` of only the leaves of ``mask_tree``. """ - mask_group = gcl.MaskGroup(agg_op="or") # type: ignore + mask_group = gcl.MaskGroup(agg_op=mask_tree.agg_op) for name, branch in mask_tree.items(): mask_group.update(dict(_leaf_mask_iter(name, branch))) # type: ignore return mask_group From 5e607c2d8bc199de4e4ba53c1dd2d1767dd7181d Mon Sep 17 00:00:00 2001 From: Jacan Chaplais Date: Wed, 11 Oct 2023 19:35:38 +0100 Subject: [PATCH 2/3] added "leaves" literal option to MaskGroup.flatten how #163 --- graphicle/data.py | 48 ++++++++++++++++++++++++++++------------------- 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/graphicle/data.py b/graphicle/data.py index 1477cd4..c5f183c 100644 --- a/graphicle/data.py +++ b/graphicle/data.py @@ -914,7 +914,7 @@ def recursive_drop( return mask_group def flatten( - self, how: ty.Literal["rise", "agg"] = "rise" + self, how: ty.Literal["rise", "agg", "leaves"] = "rise" ) -> "MaskGroup[MaskArray]": """Removes nesting such that the ``MaskGroup`` contains only ``MaskArray`` instances, and no other ``MaskGroup``. @@ -922,18 +922,23 @@ def flatten( .. versionadded:: 0.1.11 .. versionchanged:: 0.2.6 - Added ``'how'`` parameter. + Added ``how`` parameter. + + .. versionchanged:: 0.3.7 + Added ``leaves`` option for ``how`` parameter. Parameters ---------- - how : {'rise', 'agg'} - Method used to convert into flat ``MaskGroup``. ``'rise'`` + how : {'rise', 'agg', 'leaves'} + Method used to convert into flat ``MaskGroup``. ``rise`` recurses through nested levels, raising all contained - ``MaskArray`` instances to the top level. ``'agg'`` loops - over the top level of ``MaskBase`` objects, leaving - top-level ``MaskArray`` objects as-is, but calling the - aggregation operation over any ``MaskGroup``. Default is - ``'rise'``. + ``MaskArray`` instances to the top level. + ``agg`` loops over the top level of ``MaskBase`` objects, + leaving top-level ``MaskArray`` objects as-is, but calling + the aggregation operation over any ``MaskGroup``. + ``leaves`` brings the innermosted nested ``MaskArray`` + instances to the top level, discarding the rest. + Default is``'rise'``. Returns ------- @@ -942,7 +947,19 @@ def flatten( at the top level. """ - def leaves( + if how == "leaves": + from graphicle.select import leaf_masks + + return leaf_masks(self) + if how == "agg": + return self.__class__( + cl.OrderedDict( + zip(self.keys(), map(op.attrgetter("data"), self.values())) + ), + "or", + ) + + def visit( mask_group: "MaskGroup", ) -> ty.Iterator[ty.Tuple[str, base.MaskLike]]: for key, val in mask_group.items(): @@ -950,18 +967,11 @@ def leaves( continue if isinstance(val, type(self)): yield key, val.data - yield from leaves(val) + yield from visit(val) else: yield key, val - if how == "rise": - return self.__class__(cl.OrderedDict(leaves(self)), "or") # type: ignore - return self.__class__( - cl.OrderedDict( - zip(self.keys(), map(op.attrgetter("data"), self.values())) - ), - "or", - ) + return self.__class__(cl.OrderedDict(visit(self)), self.agg_op) def serialize(self) -> ty.Dict[str, ty.Any]: """Returns serialized data as a dictionary. From d0c43c9fd14abeb18ec51a07c1a57161af5947a8 Mon Sep 17 00:00:00 2001 From: Jacan Chaplais Date: Wed, 11 Oct 2023 19:44:52 +0100 Subject: [PATCH 3/3] formatting issue #163 --- graphicle/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphicle/data.py b/graphicle/data.py index c5f183c..abd9f71 100644 --- a/graphicle/data.py +++ b/graphicle/data.py @@ -938,7 +938,7 @@ def flatten( the aggregation operation over any ``MaskGroup``. ``leaves`` brings the innermosted nested ``MaskArray`` instances to the top level, discarding the rest. - Default is``'rise'``. + Default is ``rise``. Returns -------