From 6ba3df39d92547b4a2287dd02fc0cc16de9d8a8c Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Thu, 6 Jul 2023 10:00:25 +0100 Subject: [PATCH] Some tests for site mode --- c/tests/test_trees.c | 47 +++++++++++++++---- c/tskit/trees.c | 34 -------------- python/tests/test_divmat.py | 86 ++++++++++++++++++++--------------- python/tests/test_lowlevel.py | 4 ++ 4 files changed, 93 insertions(+), 78 deletions(-) diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c index 95187e2d61..cf8bc3fb56 100644 --- a/c/tests/test_trees.c +++ b/c/tests/test_trees.c @@ -3802,19 +3802,45 @@ test_simplest_divergence_matrix(void) const char *edges = "0 1 2 0,1\n"; tsk_treeseq_t ts; tsk_id_t sample_ids[] = { 0, 1 }; - double D[4] = { 0, 2, 2, 0 }; + double D_branch[4] = { 0, 2, 2, 0 }; + double D_site[4] = { 0, 0, 0, 0 }; double result[4]; int ret; tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0); - ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result); + ret = tsk_treeseq_divergence_matrix( + &ts, 2, sample_ids, 0, NULL, TSK_STAT_BRANCH, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - assert_arrays_almost_equal(4, D, result); + assert_arrays_almost_equal(4, D_branch, result); - ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, 0, D); + ret = tsk_treeseq_divergence_matrix( + &ts, 2, sample_ids, 0, NULL, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(4, D_site, result); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_BRANCH, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - assert_arrays_almost_equal(4, D, result); + assert_arrays_almost_equal(4, D_branch, result); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(4, D_site, result); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_NODE, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSUPPORTED_STAT_MODE); + + ret = tsk_treeseq_divergence_matrix( + &ts, 0, NULL, 0, NULL, TSK_STAT_SITE | TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MULTIPLE_STAT_MODES); + + ret = tsk_treeseq_divergence_matrix( + &ts, 0, NULL, 0, NULL, TSK_STAT_POLARISED, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_PARAM_VALUE); + + ret = tsk_treeseq_divergence_matrix( + &ts, 0, NULL, 0, NULL, TSK_STAT_SPAN_NORMALISE, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_PARAM_VALUE); sample_ids[0] = -1; ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result); @@ -5393,7 +5419,8 @@ test_single_tree_divergence_matrix_multi_root(void) tsk_treeseq_t ts; int ret; double result[16]; - double D[16] = { 0, 2, 3, 3, 2, 0, 3, 3, 3, 3, 0, 4, 3, 3, 4, 0 }; + double D_branch[16] = { 0, 2, 3, 3, 2, 0, 3, 3, 3, 3, 0, 4, 3, 3, 4, 0 }; + double D_site[16] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; const char *nodes = "1 0 -1 -1\n" "1 0 -1 -1\n" /* 2.00┊ 5 ┊ */ @@ -5406,9 +5433,13 @@ test_single_tree_divergence_matrix_multi_root(void) tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0); - ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, 0, result); + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_BRANCH, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - assert_arrays_almost_equal(16, result, D); + assert_arrays_almost_equal(16, result, D_branch); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(16, result, D_site); tsk_treeseq_free(&ts); } diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 9c53ffca67..9383cbc738 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -6460,40 +6460,6 @@ tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_sam } } ret = 0; - -/* n = len(samples) */ -/* D = np.zeros((num_windows, n, n)) */ -/* tree = tskit.Tree(ts) */ -/* for i in range(num_windows): */ -/* left = windows[i] */ -/* right = windows[i + 1] */ -/* tree.seek(left) */ -/* # Iterate over the trees in this window */ -/* while tree.interval.left < right and tree.index != -1: */ -/* span_left = max(tree.interval.left, left) */ -/* span_right = min(tree.interval.right, right) */ -/* mutations_per_node = collections.Counter() */ -/* for site in tree.sites(): */ -/* if span_left <= site.position < span_right: */ -/* for mutation in site.mutations: */ -/* mutations_per_node[mutation.node] += 1 */ -/* for j in range(n): */ -/* u = samples[j] */ -/* for k in range(j + 1, n): */ -/* v = samples[k] */ -/* w = tree.mrca(u, v) */ -/* if w != tskit.NULL: */ -/* wu = w */ -/* wv = w */ -/* else: */ -/* wu = local_root(tree, u) */ -/* wv = local_root(tree, v) */ -/* du = sum(mutations_per_node[x] for x in rootward_path(tree, u, wu)) */ -/* dv = sum(mutations_per_node[x] for x in rootward_path(tree, v, wv)) */ -/* # NOTE: we're just accumulating the raw mutation counts, not */ -/* # multiplying by span */ -/* D[i, j, k] += du + dv */ -/* tree.next() */ out: tsk_tree_free(&tree); tsk_safe_free(mutations_per_node); diff --git a/python/tests/test_divmat.py b/python/tests/test_divmat.py index fce75d8d80..7ab7cafc12 100644 --- a/python/tests/test_divmat.py +++ b/python/tests/test_divmat.py @@ -205,7 +205,7 @@ def divergence_matrix(ts, windows=None, samples=None): return D -def stats_api_divergence_matrix(ts, windows=None, samples=None): +def stats_api_divergence_matrix(ts, windows=None, samples=None, mode="branch"): samples = ts.samples() if samples is None else samples windows_specified = windows is not None windows = [0, ts.sequence_length] if windows is None else list(windows) @@ -247,7 +247,7 @@ def stats_api_divergence_matrix(ts, windows=None, samples=None): X = ts.divergence( sample_sets, indexes=indexes, - mode="branch", + mode=mode, span_normalise=False, windows=windows, ) @@ -721,8 +721,11 @@ def check(self, ts, windows=None, samples=None, num_threads=0, mode="branch"): num_threads=num_threads, mode=mode, ) + D2 = stats_api_divergence_matrix( + ts, windows=windows, samples=samples, mode=mode + ) + assert D1.shape == D2.shape if mode == "branch": - D2 = stats_api_divergence_matrix(ts, windows=windows, samples=samples) # If we have missing data then parts of the divmat are defined to be zero, # so relative tolerances aren't useful. Because the stats API # method necessarily involves subtracting away all of the previous @@ -730,17 +733,14 @@ def check(self, ts, windows=None, samples=None, num_threads=0, mode="branch"): # here. This value for atol is what is needed to get the tests to # pass in practise. has_missing_data = any(tree._has_isolated_samples() for tree in ts.trees()) - assert D1.shape == D2.shape atol = 1e-12 if has_missing_data else 0 np.testing.assert_allclose(D1, D2, atol=atol) else: assert mode == "site" - D2 = site_divergence_matrix_naive(ts, windows=windows, samples=samples) - # print("D1 = ") - # print(D1) - # print("D2 = ") - # print(D2) - assert D1.shape == D2.shape + if np.any(ts.mutations_parent != tskit.NULL): + # The stats API computes something slightly different when we have + # recurrent mutations, so fall back to the naive version. + D2 = site_divergence_matrix_naive(ts, windows=windows, samples=samples) np.testing.assert_array_equal(D1, D2) @pytest.mark.parametrize("ts", get_example_tree_sequences()) @@ -749,23 +749,27 @@ def test_defaults(self, ts, mode): self.check(ts, mode=mode) @pytest.mark.parametrize("ts", get_example_tree_sequences()) - def test_subset_samples(self, ts): + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_subset_samples(self, ts, mode): n = min(ts.num_samples, 2) - self.check(ts, samples=ts.samples()[:n]) + self.check(ts, samples=ts.samples()[:n], mode=mode) @pytest.mark.parametrize("ts", get_example_tree_sequences()) - def test_windows(self, ts): + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_windows(self, ts, mode): windows = np.linspace(0, ts.sequence_length, num=13) - self.check(ts, windows=windows) + self.check(ts, windows=windows, mode=mode) @pytest.mark.parametrize("ts", get_example_tree_sequences()) - def test_threads_no_windows(self, ts): - self.check(ts, num_threads=5) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_threads_no_windows(self, ts, mode): + self.check(ts, num_threads=5, mode=mode) @pytest.mark.parametrize("ts", get_example_tree_sequences()) - def test_threads_windows(self, ts): + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_threads_windows(self, ts, mode): windows = np.linspace(0, ts.sequence_length, num=11) - self.check(ts, num_threads=5, windows=windows) + self.check(ts, num_threads=5, windows=windows, mode=mode) class TestSiteDivergence: @@ -780,26 +784,29 @@ def test_simulation_example(self): class TestThreadsNoWindows: - def check(self, ts, num_threads, samples=None): - D1 = ts.divergence_matrix(num_threads=0, samples=samples) - D2 = ts.divergence_matrix(num_threads=num_threads, samples=samples) + def check(self, ts, num_threads, samples=None, mode=None): + D1 = ts.divergence_matrix(num_threads=0, samples=samples, mode=mode) + D2 = ts.divergence_matrix(num_threads=num_threads, samples=samples, mode=mode) np.testing.assert_array_almost_equal(D1, D2) @pytest.mark.parametrize("num_threads", [1, 2, 3, 5, 26, 27]) - def test_all_trees(self, num_threads): + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees(self, num_threads, mode): ts = tsutil.all_trees_ts(4) assert ts.num_trees == 26 - self.check(ts, num_threads) + self.check(ts, num_threads, mode=mode) @pytest.mark.parametrize("samples", [None, [0, 1]]) - def test_all_trees_samples(self, samples): + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees_samples(self, samples, mode): ts = tsutil.all_trees_ts(4) assert ts.num_trees == 26 - self.check(ts, 2, samples) + self.check(ts, 2, samples, mode=mode) @pytest.mark.parametrize("n", [2, 3, 5, 15]) @pytest.mark.parametrize("num_threads", range(1, 5)) - def test_simple_sims(self, n, num_threads): + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_simple_sims(self, n, num_threads, mode): ts = msprime.sim_ancestry( n, ploidy=1, @@ -809,14 +816,16 @@ def test_simple_sims(self, n, num_threads): random_seed=1234, ) assert ts.num_trees >= 2 - self.check(ts, num_threads) + self.check(ts, num_threads, mode=mode) class TestThreadsWindows: - def check(self, ts, num_threads, *, windows, samples=None): - D1 = ts.divergence_matrix(num_threads=0, windows=windows, samples=samples) + def check(self, ts, num_threads, *, windows, samples=None, mode=None): + D1 = ts.divergence_matrix( + num_threads=0, windows=windows, samples=samples, mode=mode + ) D2 = ts.divergence_matrix( - num_threads=num_threads, windows=windows, samples=samples + num_threads=num_threads, windows=windows, samples=samples, mode=mode ) np.testing.assert_array_almost_equal(D1, D2) @@ -832,10 +841,11 @@ def check(self, ts, num_threads, *, windows, samples=None): ([5.1, 5.2, 6.5],), ], ) - def test_all_trees(self, num_threads, windows): + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees(self, num_threads, windows, mode): ts = tsutil.all_trees_ts(4) assert ts.num_trees == 26 - self.check(ts, num_threads, windows=windows) + self.check(ts, num_threads, windows=windows, mode=mode) @pytest.mark.parametrize("samples", [None, [0, 1]]) @pytest.mark.parametrize( @@ -845,9 +855,10 @@ def test_all_trees(self, num_threads, windows): (None,), ], ) - def test_all_trees_samples(self, samples, windows): + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees_samples(self, samples, windows, mode): ts = tsutil.all_trees_ts(4) - self.check(ts, 2, windows=windows, samples=samples) + self.check(ts, 2, windows=windows, samples=samples, mode=mode) @pytest.mark.parametrize("num_threads", range(1, 5)) @pytest.mark.parametrize( @@ -860,7 +871,8 @@ def test_all_trees_samples(self, samples, windows): (list(range(100)),), ], ) - def test_simple_sims(self, num_threads, windows): + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_simple_sims(self, num_threads, windows, mode): ts = msprime.sim_ancestry( 15, ploidy=1, @@ -870,7 +882,9 @@ def test_simple_sims(self, num_threads, windows): random_seed=1234, ) assert ts.num_trees >= 2 - self.check(ts, num_threads, windows=windows) + ts = msprime.sim_mutations(ts, rate=0.01, random_seed=1234) + assert ts.num_mutations > 10 + self.check(ts, num_threads, windows=windows, mode=mode) # NOTE these are tests that are for more general functionality that might diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 29f4efaa34..530ec7223f 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -1544,6 +1544,10 @@ def test_divergence_matrix(self): ts.divergence_matrix(windows=[-1, 0, 1]) with pytest.raises(ValueError): ts.divergence_matrix(windows=[0, 1], samples="sdf") + with pytest.raises(ValueError, match="Unrecognised stats mode"): + ts.divergence_matrix(windows=[0, 1], mode="sdf") + with pytest.raises(_tskit.LibraryError, match="UNSUPPORTED_STAT_MODE"): + ts.divergence_matrix(windows=[0, 1], mode="node") def test_load_tables_build_indexes(self): for ts in self.get_example_tree_sequences():