Skip to content

Commit

Permalink
test MaskGroup serialization is invertible #172
Browse files Browse the repository at this point in the history
  • Loading branch information
jacanchaplais committed Feb 28, 2024
1 parent 0326f60 commit 9b83b8d
Showing 1 changed file with 62 additions and 0 deletions.
62 changes: 62 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import dataclasses as dc
import math
import random
import string

import numpy as np
import pytest
Expand All @@ -18,6 +19,13 @@
ZERO_TOL = 1.0e-10 # absolute tolerance for detecting zero-values


def random_alphanum(length: int) -> str:
return "".join(
random.choice(string.ascii_letters + string.digits)
for _ in range(length)
)


@dc.dataclass
class MomentumExample:
"""Dataclass for representing the four-momentum of a particle, based
Expand Down Expand Up @@ -111,3 +119,57 @@ def test_pmu_zero_pt() -> None:
with pytest.warns(gcl.base.NumericalStabilityWarning):
phi_invalid = math.isnan(pmu_zero_pt.phi.item())
assert phi_invalid, "Azimuth is not NaN when pT is low"


def generate_tree(
max_width: int, max_depth: int, leaf_length: int
) -> gcl.MaskGroup:
"""Generates a nested MaskGroup tree structure, with random branch
widths and depths.
Parameters
----------
max_width, max_depth : int
Maximum limits on the nested tree structure.
leaf_length : int
The number of elements in the leaf MaskArrays.
Returns
-------
MaskGroup
Tree structure, with random structure, and random MaskArrays at
the leaf levels.
"""
rng = np.random.default_rng()

def generate_branch(depth: int) -> gcl.base.MaskBase:
if (depth == 0) or (
random.choice((True, False)) and not (depth == max_depth)
):
return gcl.MaskArray(
rng.integers(0, 2, size=leaf_length, dtype=np.bool_)
)
num_children = random.randint(1, max_width)
return gcl.MaskGroup(
{
random_alphanum(10): generate_branch(depth - 1)
for _ in range(num_children)
}
)

return generate_branch(max_depth)


def test_maskgroup_serialize_inverse() -> None:
"""Tests that instantiating a MaskGroup from its serialization
yields identical results to the original.
"""
invertible = True
for _ in range(10):
maskgroup = generate_tree(
max_width=5,
max_depth=10,
leaf_length=random.randint(0, 1_000),
)
invertible &= maskgroup.equal_to(gcl.MaskGroup(maskgroup.serialize()))
assert invertible, "Serializing MaskGroups is not invertible."

0 comments on commit 9b83b8d

Please sign in to comment.