Skip to content

Commit

Permalink
Test various fields with clumerge()
Browse files Browse the repository at this point in the history
  • Loading branch information
nunofachada committed Jun 19, 2023
1 parent 91c9bd1 commit 82cbde5
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 0 deletions.
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def pytest_addoption(parser):
t_ds_ot_n: Sequence[int]
t_ds_od_n: Sequence[int]
t_no_clusters_field: Sequence[bool]
t_ds_cgs_n: Sequence[int]


def pytest_report_header(config):
Expand Down Expand Up @@ -128,6 +129,7 @@ def pytest_generate_tests(metafunc):
t_ds_ot_n = [0]
t_ds_od_n = [0, 1]
t_no_clusters_field = (False,)
t_ds_cgs_n = [2]
elif test_level == "ci":
# CI test level
seeds = [123]
Expand All @@ -151,6 +153,7 @@ def pytest_generate_tests(metafunc):
t_ds_ot_n = [0, 1]
t_ds_od_n = [0, 1]
t_no_clusters_field = [False, True]
t_ds_cgs_n = [2, 3]
elif test_level == "normal":
seeds = [0, 123, 6789]
t_ndims = [1, 2, 3, 10]
Expand All @@ -173,6 +176,7 @@ def pytest_generate_tests(metafunc):
t_ds_ot_n = [0, 1]
t_ds_od_n = [0, 1, 2]
t_no_clusters_field = [False, True]
t_ds_cgs_n = [2, 3, 4]
elif test_level == "full":
seeds = [0, 123, 6789, 9876543]
t_ndims = [1, 2, 3, 5, 10, 30]
Expand All @@ -195,6 +199,7 @@ def pytest_generate_tests(metafunc):
t_ds_ot_n = [0, 1, 2]
t_ds_od_n = [0, 1, 2]
t_no_clusters_field = [False, True]
t_ds_cgs_n = [2, 3, 4, 5]
else:
raise ValueError(f"Unknown test level {test_level!r}")

Expand Down Expand Up @@ -224,6 +229,7 @@ def param_if(param: str, value: Sequence[Any]):
param_if("ds_ot_n", t_ds_ot_n)
param_if("ds_od_n", t_ds_od_n)
param_if("no_clusters_field", t_no_clusters_field)
param_if("ds_cgs_n", t_ds_cgs_n)


@pytest.fixture()
Expand Down
70 changes: 70 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,3 +928,73 @@ def test_clumerge_general(
assert mds["points"].shape == expect_shape
assert max(mds["clusters"]) == tclu
assert can_cast(mds["clusters"].dtype, int64)


def test_clumerge_fields(
prng: Generator,
ndims,
ds_cgs_n,
):
"""# Test clumerge with data from clugen() and merging more fields."""
datasets: MutableSequence[NamedTuple | Mapping[str, ArrayLike]] = []
tclu: int = 0
tclu_i: int = 0
tpts: int = 0

# Create data sets with clugen()
for _ in range(ds_cgs_n):
# clugen() should run without problem
with warnings.catch_warnings():
# Check that the function runs without warnings
warnings.simplefilter("error")

ds_cgs = clugen(
ndims,
prng.integers(1, high=11),
prng.integers(1, high=101),
prng.random(size=ndims),
prng.random(),
prng.random(size=ndims),
prng.random(),
prng.random(),
prng.random(),
allow_empty=True,
rng=prng,
)

tclu += len(unique(ds_cgs.clusters))
tpts += len(ds_cgs.points)
tclu_i += len(ds_cgs.sizes)
datasets.append(ds_cgs)

# Check that clumerge() is able to merge data set fields related to points
# without warnings
with warnings.catch_warnings():
# Check that the function runs without warnings
warnings.simplefilter("error")
mds = clumerge(*datasets, fields=("points", "clusters", "projections"))

# Check that the number of clusters and points is correct
expect_shape = (tpts,) if ndims == 1 else (tpts, ndims)
assert mds["points"].shape == expect_shape
assert mds["projections"].shape == expect_shape
assert max(mds["clusters"]) == tclu
assert can_cast(mds["clusters"].dtype, int64)

# Check that clumerge() is able to merge data set fields related to clusters
# without warnings
with warnings.catch_warnings():
# Check that the function runs without warnings
warnings.simplefilter("error")
mds = clumerge(*datasets, fields=("sizes", "centers", "directions", "angles", "lengths"),
clusters_field=None,
)

# Check that the cluster-related fields have the correct sizes
expect_shape = (tclu_i,) if ndims == 1 else (tclu_i, ndims)
assert len(mds["sizes"]) == tclu_i
assert can_cast(mds["sizes"], int64)
assert mds["centers"].shape == expect_shape
assert mds["directions"].shape == expect_shape
assert len(mds["angles"]) == tclu_i
assert len(mds["lengths"]) == tclu_i

0 comments on commit 82cbde5

Please sign in to comment.