Skip to content

Commit

Permalink
leaf_masks now inherits agg_op from mask_tree #163
Browse files Browse the repository at this point in the history
  • Loading branch information
jacanchaplais committed Oct 11, 2023
1 parent b72b459 commit 86ebc39
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions graphicle/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,8 +870,7 @@ def _leaf_mask_iter(
for name, mask in branch.items():
if exclude_latent and name == "latent":
continue
# TODO: look into contravariant type for this
yield from _leaf_mask_iter(name, mask, exclude_latent) # type: ignore
yield from _leaf_mask_iter(name, mask, exclude_latent)


def leaf_masks(mask_tree: gcl.MaskGroup) -> gcl.MaskGroup[gcl.MaskArray]:
Expand All @@ -882,6 +881,9 @@ def leaf_masks(mask_tree: gcl.MaskGroup) -> gcl.MaskGroup[gcl.MaskArray]:
.. versionadded:: 0.1.11
.. versionchanged 0.3.7
Output ``MaskGroup`` matches agg_op of ``mask_tree``.
Parameters
----------
mask_tree : MaskGroup
Expand All @@ -893,7 +895,7 @@ def leaf_masks(mask_tree: gcl.MaskGroup) -> gcl.MaskGroup[gcl.MaskArray]:
MaskGroup
Flat ``MaskGroup`` of only the leaves of ``mask_tree``.
"""
mask_group = gcl.MaskGroup(agg_op="or") # type: ignore
mask_group = gcl.MaskGroup(agg_op=mask_tree.agg_op)
for name, branch in mask_tree.items():
mask_group.update(dict(_leaf_mask_iter(name, branch))) # type: ignore
return mask_group
Expand Down

0 comments on commit 86ebc39

Please sign in to comment.