From 82cbde55eed52f6c5976b73ffb9234a22bf8bb03 Mon Sep 17 00:00:00 2001 From: Nuno Fachada Date: Mon, 19 Jun 2023 23:59:08 +0100 Subject: [PATCH] Test various fields with clumerge() --- tests/conftest.py | 6 ++++ tests/test_main.py | 70 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 8fb190f..4d44eea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -95,6 +95,7 @@ def pytest_addoption(parser): t_ds_ot_n: Sequence[int] t_ds_od_n: Sequence[int] t_no_clusters_field: Sequence[bool] +t_ds_cgs_n: Sequence[int] def pytest_report_header(config): @@ -128,6 +129,7 @@ def pytest_generate_tests(metafunc): t_ds_ot_n = [0] t_ds_od_n = [0, 1] t_no_clusters_field = (False,) + t_ds_cgs_n = [2] elif test_level == "ci": # CI test level seeds = [123] @@ -151,6 +153,7 @@ def pytest_generate_tests(metafunc): t_ds_ot_n = [0, 1] t_ds_od_n = [0, 1] t_no_clusters_field = [False, True] + t_ds_cgs_n = [2, 3] elif test_level == "normal": seeds = [0, 123, 6789] t_ndims = [1, 2, 3, 10] @@ -173,6 +176,7 @@ def pytest_generate_tests(metafunc): t_ds_ot_n = [0, 1] t_ds_od_n = [0, 1, 2] t_no_clusters_field = [False, True] + t_ds_cgs_n = [2, 3, 4] elif test_level == "full": seeds = [0, 123, 6789, 9876543] t_ndims = [1, 2, 3, 5, 10, 30] @@ -195,6 +199,7 @@ def pytest_generate_tests(metafunc): t_ds_ot_n = [0, 1, 2] t_ds_od_n = [0, 1, 2] t_no_clusters_field = [False, True] + t_ds_cgs_n = [2, 3, 4, 5] else: raise ValueError(f"Unknown test level {test_level!r}") @@ -224,6 +229,7 @@ def param_if(param: str, value: Sequence[Any]): param_if("ds_ot_n", t_ds_ot_n) param_if("ds_od_n", t_ds_od_n) param_if("no_clusters_field", t_no_clusters_field) + param_if("ds_cgs_n", t_ds_cgs_n) @pytest.fixture() diff --git a/tests/test_main.py b/tests/test_main.py index de71c4f..5d9e4fb 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -928,3 +928,73 @@ def test_clumerge_general( assert mds["points"].shape == expect_shape assert max(mds["clusters"]) == tclu assert can_cast(mds["clusters"].dtype, int64) + + +def test_clumerge_fields( + prng: Generator, + ndims, + ds_cgs_n, +): + """# Test clumerge with data from clugen() and merging more fields.""" + datasets: MutableSequence[NamedTuple | Mapping[str, ArrayLike]] = [] + tclu: int = 0 + tclu_i: int = 0 + tpts: int = 0 + + # Create data sets with clugen() + for _ in range(ds_cgs_n): + # clugen() should run without problem + with warnings.catch_warnings(): + # Check that the function runs without warnings + warnings.simplefilter("error") + + ds_cgs = clugen( + ndims, + prng.integers(1, high=11), + prng.integers(1, high=101), + prng.random(size=ndims), + prng.random(), + prng.random(size=ndims), + prng.random(), + prng.random(), + prng.random(), + allow_empty=True, + rng=prng, + ) + + tclu += len(unique(ds_cgs.clusters)) + tpts += len(ds_cgs.points) + tclu_i += len(ds_cgs.sizes) + datasets.append(ds_cgs) + + # Check that clumerge() is able to merge data set fields related to points + # without warnings + with warnings.catch_warnings(): + # Check that the function runs without warnings + warnings.simplefilter("error") + mds = clumerge(*datasets, fields=("points", "clusters", "projections")) + + # Check that the number of clusters and points is correct + expect_shape = (tpts,) if ndims == 1 else (tpts, ndims) + assert mds["points"].shape == expect_shape + assert mds["projections"].shape == expect_shape + assert max(mds["clusters"]) == tclu + assert can_cast(mds["clusters"].dtype, int64) + + # Check that clumerge() is able to merge data set fields related to clusters + # without warnings + with warnings.catch_warnings(): + # Check that the function runs without warnings + warnings.simplefilter("error") + mds = clumerge(*datasets, fields=("sizes", "centers", "directions", "angles", "lengths"), + clusters_field=None, + ) + + # Check that the cluster-related fields have the correct sizes + expect_shape = (tclu_i,) if ndims == 1 else (tclu_i, ndims) + assert len(mds["sizes"]) == tclu_i + assert can_cast(mds["sizes"], int64) + assert mds["centers"].shape == expect_shape + assert mds["directions"].shape == expect_shape + assert len(mds["angles"]) == tclu_i + assert len(mds["lengths"]) == tclu_i