Skip to content

Commit

Permalink
Some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Jul 5, 2023
1 parent 71034e7 commit a5f9756
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 18 deletions.
11 changes: 10 additions & 1 deletion c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -6377,7 +6377,16 @@ count_mutations_on_path(tsk_id_t u, tsk_id_t v, const tsk_id_t *restrict parent,
tv = time[v];
}
}
tsk_bug_assert((u == TSK_NULL) == (v == TSK_NULL));
if (u != v) {
while (u != TSK_NULL) {
count += mutations_per_node[u];
u = parent[u];
}
while (v != TSK_NULL) {
count += mutations_per_node[v];
v = parent[v];
}
}
return count;
}

Expand Down
48 changes: 32 additions & 16 deletions python/tests/test_divmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
from tests import tsutil
from tests.test_highlevel import get_example_tree_sequences

DIVMAT_MODES = ["branch", "site"]

# ↑ See https://github.com/tskit-dev/tskit/issues/1804 for when
# we can remove this.

Expand Down Expand Up @@ -712,25 +714,39 @@ class TestSuiteExamples:
Python code above on.
"""

def check(self, ts, windows=None, samples=None, num_threads=0):
D1 = stats_api_divergence_matrix(ts, windows=windows, samples=samples)
D2 = ts.divergence_matrix(
windows=windows, samples=samples, num_threads=num_threads
def check(self, ts, windows=None, samples=None, num_threads=0, mode="branch"):
D1 = ts.divergence_matrix(
windows=windows,
samples=samples,
num_threads=num_threads,
mode=mode,
)
# If we have missing data then parts of the divmat are defined to be zero,
# so relative tolerances aren't useful. Because the stats API
# method necessarily involves subtracting away all of the previous
# values for an empty tree, there is a degree of numerical imprecision
# here. This value for atol is what is needed to get the tests to
# pass in practise.
has_missing_data = any(tree._has_isolated_samples() for tree in ts.trees())
assert D1.shape == D2.shape
atol = 1e-12 if has_missing_data else 0
np.testing.assert_allclose(D1, D2, atol=atol)
if mode == "branch":
D2 = stats_api_divergence_matrix(ts, windows=windows, samples=samples)
# If we have missing data then parts of the divmat are defined to be zero,
# so relative tolerances aren't useful. Because the stats API
# method necessarily involves subtracting away all of the previous
# values for an empty tree, there is a degree of numerical imprecision
# here. This value for atol is what is needed to get the tests to
# pass in practise.
has_missing_data = any(tree._has_isolated_samples() for tree in ts.trees())
assert D1.shape == D2.shape
atol = 1e-12 if has_missing_data else 0
np.testing.assert_allclose(D1, D2, atol=atol)
else:
assert mode == "site"
D2 = site_divergence_matrix_naive(ts, windows=windows, samples=samples)
# print("D1 = ")
# print(D1)
# print("D2 = ")
# print(D2)
assert D1.shape == D2.shape
np.testing.assert_array_equal(D1, D2)

@pytest.mark.parametrize("ts", get_example_tree_sequences())
def test_defaults(self, ts):
self.check(ts)
@pytest.mark.parametrize("mode", DIVMAT_MODES)
def test_defaults(self, ts, mode):
self.check(ts, mode=mode)

@pytest.mark.parametrize("ts", get_example_tree_sequences())
def test_subset_samples(self, ts):
Expand Down
2 changes: 1 addition & 1 deletion python/tests/test_highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def get_gap_examples():
assert len(t.parent_dict) == 0
found = True
assert found
ret.append((f"gap {x}", ts))
ret.append((f"gap_{x}", ts))
# Give an example with a gap at the end.
ts = msprime.simulate(10, random_seed=5, recombination_rate=1)
tables = get_table_collection_copy(ts.dump_tables(), 2)
Expand Down

0 comments on commit a5f9756

Please sign in to comment.