diff --git a/graphicle/data.py b/graphicle/data.py index abd9f71..2cbc334 100644 --- a/graphicle/data.py +++ b/graphicle/data.py @@ -642,7 +642,9 @@ def _mask_neq(mask1: base.MaskLike, mask2: base.MaskLike) -> MaskArray: return MaskArray(np.not_equal(mask1, mask2)) -_IN_MASK_DICT = ty.OrderedDict[str, ty.Union[MaskArray, base.BoolVector]] +_IN_MASK_DICT = ty.Mapping[ + str, ty.Union[MaskArray, base.BoolVector, ty.Iterable[bool]] +] _MASK_DICT = ty.OrderedDict[str, MaskArray] @@ -651,12 +653,32 @@ def _mask_dict_convert(masks: _IN_MASK_DICT) -> _MASK_DICT: for key, val in masks.items(): if isinstance(val, MaskArray) or isinstance(val, MaskGroup): mask = val + elif isinstance(val, cla.Mapping): + mask = MaskGroup(_mask_dict_convert(val)) else: mask = MaskArray(val) out_masks[key] = mask return out_masks +def _maskgroup_equal( + group_1: "MaskGroup", group_2: "MaskGroup", check_order: bool +) -> bool: + key_struct = tuple if check_order else set + if key_struct(group_1) != key_struct(group_2): + return False + for key in group_1: + mask_1, mask_2 = group_1[key], group_2[key] + if type(mask_1) != type(mask_2): + return False + if isinstance(mask_1, MaskGroup): + if not _maskgroup_equal(mask_1, mask_2, check_order): + return False + elif not np.array_equal(mask_1.data, mask_2.data): + return False + return True + + class MaskAggOp(Enum): AND = "and" OR = "or" @@ -980,6 +1002,29 @@ def serialize(self) -> ty.Dict[str, ty.Any]: """ return {key: val.serialize() for key, val in self._mask_arrays.items()} + def equal_to(self, other: "MaskGroup", check_order: bool = False) -> bool: + """Checks whether this instance is identical to ``other`` + ``MaskGroup``, comparing keys at all levels of nesting, and + boolean array data at the leaf level. + + .. versionadded:: 0.3.9 + + Parameters + ---------- + other : MaskGroup + Other instance, against which to compare for equality. + check_order : bool + If ``True``, will check that the ordering of elements is + identical. 
def random_alphanum(length: int) -> str:
    """Returns a random string of ASCII letters and digits.

    Parameters
    ----------
    length : int
        Number of characters in the generated string.

    Returns
    -------
    str
        String of ``length`` characters, each drawn independently with
        ``random.choice`` from ``string.ascii_letters + string.digits``.
    """
    # Build the alphabet once, rather than re-concatenating it for
    # every character drawn.
    alphabet = string.ascii_letters + string.digits
    return "".join(random.choice(alphabet) for _ in range(length))
def test_maskgroup_serialize_inverse() -> None:
    """Checks that a MaskGroup rebuilt from its own serialization
    compares equal to the MaskGroup it was serialized from.
    """
    num_trials = 10
    outcomes = []
    for _ in range(num_trials):
        original = generate_tree(
            max_width=5,
            max_depth=10,
            leaf_length=random.randint(0, 1_000),
        )
        rebuilt = gcl.MaskGroup(original.serialize())
        # Record every outcome; all trials run before asserting.
        outcomes.append(original.equal_to(rebuilt))
    assert all(outcomes), "Serializing MaskGroups is not invertible."