diff --git a/tests/test_data.py b/tests/test_data.py index 07f7943..fce7859 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -9,6 +9,7 @@ import dataclasses as dc import math import random +import string import numpy as np import pytest @@ -18,6 +19,13 @@ ZERO_TOL = 1.0e-10 # absolute tolerance for detecting zero-values +def random_alphanum(length: int) -> str: + return "".join( + random.choice(string.ascii_letters + string.digits) + for _ in range(length) + ) + + @dc.dataclass class MomentumExample: """Dataclass for representing the four-momentum of a particle, based @@ -111,3 +119,57 @@ def test_pmu_zero_pt() -> None: with pytest.warns(gcl.base.NumericalStabilityWarning): phi_invalid = math.isnan(pmu_zero_pt.phi.item()) assert phi_invalid, "Azimuth is not NaN when pT is low" + + +def generate_tree( + max_width: int, max_depth: int, leaf_length: int +) -> gcl.MaskGroup: + """Generates a nested MaskGroup tree structure, with random branch + widths and depths. + + Parameters + ---------- + max_width, max_depth : int + Maximum limits on the nested tree structure. + leaf_length : int + The number of elements in the leaf MaskArrays. + + Returns + ------- + MaskGroup + Tree structure, with random structure, and random MaskArrays at + the leaf levels. + """ + rng = np.random.default_rng() + + def generate_branch(depth: int) -> gcl.base.MaskBase: + if (depth == 0) or ( + random.choice((True, False)) and not (depth == max_depth) + ): + return gcl.MaskArray( + rng.integers(0, 2, size=leaf_length, dtype=np.bool_) + ) + num_children = random.randint(1, max_width) + return gcl.MaskGroup( + { + random_alphanum(10): generate_branch(depth - 1) + for _ in range(num_children) + } + ) + + return generate_branch(max_depth) + + +def test_maskgroup_serialize_inverse() -> None: + """Tests that instantiating a MaskGroup from its serialization + yields identical results to the original. + """ + invertible = True + for _ in range(10): + maskgroup = generate_tree( + max_width=5, + max_depth=10, + leaf_length=random.randint(0, 1_000), + ) + invertible &= maskgroup.equal_to(gcl.MaskGroup(maskgroup.serialize())) + assert invertible, "Serializing MaskGroups is not invertible."