diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 78b758ac25..9c53ffca67 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -6377,7 +6377,16 @@ count_mutations_on_path(tsk_id_t u, tsk_id_t v, const tsk_id_t *restrict parent, tv = time[v]; } } - tsk_bug_assert((u == TSK_NULL) == (v == TSK_NULL)); + if (u != v) { + while (u != TSK_NULL) { + count += mutations_per_node[u]; + u = parent[u]; + } + while (v != TSK_NULL) { + count += mutations_per_node[v]; + v = parent[v]; + } + } return count; } diff --git a/python/tests/test_divmat.py b/python/tests/test_divmat.py index 3196f9757c..fce75d8d80 100644 --- a/python/tests/test_divmat.py +++ b/python/tests/test_divmat.py @@ -32,6 +32,8 @@ from tests import tsutil from tests.test_highlevel import get_example_tree_sequences +DIVMAT_MODES = ["branch", "site"] + # ↑ See https://github.com/tskit-dev/tskit/issues/1804 for when # we can remove this. @@ -712,25 +714,39 @@ class TestSuiteExamples: Python code above on. """ - def check(self, ts, windows=None, samples=None, num_threads=0): - D1 = stats_api_divergence_matrix(ts, windows=windows, samples=samples) - D2 = ts.divergence_matrix( - windows=windows, samples=samples, num_threads=num_threads + def check(self, ts, windows=None, samples=None, num_threads=0, mode="branch"): + D1 = ts.divergence_matrix( + windows=windows, + samples=samples, + num_threads=num_threads, + mode=mode, ) - # If we have missing data then parts of the divmat are defined to be zero, - # so relative tolerances aren't useful. Because the stats API - # method necessarily involves subtracting away all of the previous - # values for an empty tree, there is a degree of numerical imprecision - # here. This value for atol is what is needed to get the tests to - # pass in practise. - has_missing_data = any(tree._has_isolated_samples() for tree in ts.trees()) - assert D1.shape == D2.shape - atol = 1e-12 if has_missing_data else 0 - np.testing.assert_allclose(D1, D2, atol=atol) + if mode == "branch": + D2 = stats_api_divergence_matrix(ts, windows=windows, samples=samples) + # If we have missing data then parts of the divmat are defined to be zero, + # so relative tolerances aren't useful. Because the stats API + # method necessarily involves subtracting away all of the previous + # values for an empty tree, there is a degree of numerical imprecision + # here. This value for atol is what is needed to get the tests to + # pass in practise. + has_missing_data = any(tree._has_isolated_samples() for tree in ts.trees()) + assert D1.shape == D2.shape + atol = 1e-12 if has_missing_data else 0 + np.testing.assert_allclose(D1, D2, atol=atol) + else: + assert mode == "site" + D2 = site_divergence_matrix_naive(ts, windows=windows, samples=samples) + # print("D1 = ") + # print(D1) + # print("D2 = ") + # print(D2) + assert D1.shape == D2.shape + np.testing.assert_array_equal(D1, D2) @pytest.mark.parametrize("ts", get_example_tree_sequences()) - def test_defaults(self, ts): - self.check(ts) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_defaults(self, ts, mode): + self.check(ts, mode=mode) @pytest.mark.parametrize("ts", get_example_tree_sequences()) def test_subset_samples(self, ts): diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index 15fd256444..0529dda001 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -228,7 +228,7 @@ def get_gap_examples(): assert len(t.parent_dict) == 0 found = True assert found - ret.append((f"gap {x}", ts)) + ret.append((f"gap_{x}", ts)) # Give an example with a gap at the end. ts = msprime.simulate(10, random_seed=5, recombination_rate=1) tables = get_table_collection_copy(ts.dump_tables(), 2)