Skip to content

Commit

Permalink
Some tests for site mode
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Jul 6, 2023
1 parent a5f9756 commit 6ba3df3
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 78 deletions.
47 changes: 39 additions & 8 deletions c/tests/test_trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -3802,19 +3802,45 @@ test_simplest_divergence_matrix(void)
const char *edges = "0 1 2 0,1\n";
tsk_treeseq_t ts;
tsk_id_t sample_ids[] = { 0, 1 };
double D[4] = { 0, 2, 2, 0 };
double D_branch[4] = { 0, 2, 2, 0 };
double D_site[4] = { 0, 0, 0, 0 };
double result[4];
int ret;

tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0);

ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result);
ret = tsk_treeseq_divergence_matrix(
&ts, 2, sample_ids, 0, NULL, TSK_STAT_BRANCH, result);
CU_ASSERT_EQUAL_FATAL(ret, 0);
assert_arrays_almost_equal(4, D, result);
assert_arrays_almost_equal(4, D_branch, result);

ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, 0, D);
ret = tsk_treeseq_divergence_matrix(
&ts, 2, sample_ids, 0, NULL, TSK_STAT_SITE, result);
CU_ASSERT_EQUAL_FATAL(ret, 0);
assert_arrays_almost_equal(4, D_site, result);

ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_BRANCH, result);
CU_ASSERT_EQUAL_FATAL(ret, 0);
assert_arrays_almost_equal(4, D, result);
assert_arrays_almost_equal(4, D_branch, result);

ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_SITE, result);
CU_ASSERT_EQUAL_FATAL(ret, 0);
assert_arrays_almost_equal(4, D_site, result);

ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_NODE, result);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSUPPORTED_STAT_MODE);

ret = tsk_treeseq_divergence_matrix(
&ts, 0, NULL, 0, NULL, TSK_STAT_SITE | TSK_STAT_BRANCH, result);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MULTIPLE_STAT_MODES);

ret = tsk_treeseq_divergence_matrix(
&ts, 0, NULL, 0, NULL, TSK_STAT_POLARISED, result);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_PARAM_VALUE);

ret = tsk_treeseq_divergence_matrix(
&ts, 0, NULL, 0, NULL, TSK_STAT_SPAN_NORMALISE, result);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_PARAM_VALUE);

sample_ids[0] = -1;
ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result);
Expand Down Expand Up @@ -5393,7 +5419,8 @@ test_single_tree_divergence_matrix_multi_root(void)
tsk_treeseq_t ts;
int ret;
double result[16];
double D[16] = { 0, 2, 3, 3, 2, 0, 3, 3, 3, 3, 0, 4, 3, 3, 4, 0 };
double D_branch[16] = { 0, 2, 3, 3, 2, 0, 3, 3, 3, 3, 0, 4, 3, 3, 4, 0 };
double D_site[16] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };

const char *nodes = "1 0 -1 -1\n"
"1 0 -1 -1\n" /* 2.00┊ 5 ┊ */
Expand All @@ -5406,9 +5433,13 @@ test_single_tree_divergence_matrix_multi_root(void)

tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0);

ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, 0, result);
ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_BRANCH, result);
CU_ASSERT_EQUAL_FATAL(ret, 0);
assert_arrays_almost_equal(16, result, D);
assert_arrays_almost_equal(16, result, D_branch);

ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_SITE, result);
CU_ASSERT_EQUAL_FATAL(ret, 0);
assert_arrays_almost_equal(16, result, D_site);

tsk_treeseq_free(&ts);
}
Expand Down
34 changes: 0 additions & 34 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -6460,40 +6460,6 @@ tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_sam
}
}
ret = 0;

/* n = len(samples) */
/* D = np.zeros((num_windows, n, n)) */
/* tree = tskit.Tree(ts) */
/* for i in range(num_windows): */
/* left = windows[i] */
/* right = windows[i + 1] */
/* tree.seek(left) */
/* # Iterate over the trees in this window */
/* while tree.interval.left < right and tree.index != -1: */
/* span_left = max(tree.interval.left, left) */
/* span_right = min(tree.interval.right, right) */
/* mutations_per_node = collections.Counter() */
/* for site in tree.sites(): */
/* if span_left <= site.position < span_right: */
/* for mutation in site.mutations: */
/* mutations_per_node[mutation.node] += 1 */
/* for j in range(n): */
/* u = samples[j] */
/* for k in range(j + 1, n): */
/* v = samples[k] */
/* w = tree.mrca(u, v) */
/* if w != tskit.NULL: */
/* wu = w */
/* wv = w */
/* else: */
/* wu = local_root(tree, u) */
/* wv = local_root(tree, v) */
/* du = sum(mutations_per_node[x] for x in rootward_path(tree, u, wu)) */
/* dv = sum(mutations_per_node[x] for x in rootward_path(tree, v, wv)) */
/* # NOTE: we're just accumulating the raw mutation counts, not */
/* # multiplying by span */
/* D[i, j, k] += du + dv */
/* tree.next() */
out:
tsk_tree_free(&tree);
tsk_safe_free(mutations_per_node);
Expand Down
86 changes: 50 additions & 36 deletions python/tests/test_divmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def divergence_matrix(ts, windows=None, samples=None):
return D


def stats_api_divergence_matrix(ts, windows=None, samples=None):
def stats_api_divergence_matrix(ts, windows=None, samples=None, mode="branch"):
samples = ts.samples() if samples is None else samples
windows_specified = windows is not None
windows = [0, ts.sequence_length] if windows is None else list(windows)
Expand Down Expand Up @@ -247,7 +247,7 @@ def stats_api_divergence_matrix(ts, windows=None, samples=None):
X = ts.divergence(
sample_sets,
indexes=indexes,
mode="branch",
mode=mode,
span_normalise=False,
windows=windows,
)
Expand Down Expand Up @@ -721,26 +721,26 @@ def check(self, ts, windows=None, samples=None, num_threads=0, mode="branch"):
num_threads=num_threads,
mode=mode,
)
D2 = stats_api_divergence_matrix(
ts, windows=windows, samples=samples, mode=mode
)
assert D1.shape == D2.shape
if mode == "branch":
D2 = stats_api_divergence_matrix(ts, windows=windows, samples=samples)
# If we have missing data then parts of the divmat are defined to be zero,
# so relative tolerances aren't useful. Because the stats API
# method necessarily involves subtracting away all of the previous
# values for an empty tree, there is a degree of numerical imprecision
# here. This value for atol is what is needed to get the tests to
# pass in practise.
has_missing_data = any(tree._has_isolated_samples() for tree in ts.trees())
assert D1.shape == D2.shape
atol = 1e-12 if has_missing_data else 0
np.testing.assert_allclose(D1, D2, atol=atol)
else:
assert mode == "site"
D2 = site_divergence_matrix_naive(ts, windows=windows, samples=samples)
# print("D1 = ")
# print(D1)
# print("D2 = ")
# print(D2)
assert D1.shape == D2.shape
if np.any(ts.mutations_parent != tskit.NULL):
# The stats API computes something slightly different when we have
# recurrent mutations, so fall back to the naive version.
D2 = site_divergence_matrix_naive(ts, windows=windows, samples=samples)
np.testing.assert_array_equal(D1, D2)

@pytest.mark.parametrize("ts", get_example_tree_sequences())
Expand All @@ -749,23 +749,27 @@ def test_defaults(self, ts, mode):
self.check(ts, mode=mode)

@pytest.mark.parametrize("ts", get_example_tree_sequences())
def test_subset_samples(self, ts):
@pytest.mark.parametrize("mode", DIVMAT_MODES)
def test_subset_samples(self, ts, mode):
n = min(ts.num_samples, 2)
self.check(ts, samples=ts.samples()[:n])
self.check(ts, samples=ts.samples()[:n], mode=mode)

@pytest.mark.parametrize("ts", get_example_tree_sequences())
def test_windows(self, ts):
@pytest.mark.parametrize("mode", DIVMAT_MODES)
def test_windows(self, ts, mode):
windows = np.linspace(0, ts.sequence_length, num=13)
self.check(ts, windows=windows)
self.check(ts, windows=windows, mode=mode)

@pytest.mark.parametrize("ts", get_example_tree_sequences())
def test_threads_no_windows(self, ts):
self.check(ts, num_threads=5)
@pytest.mark.parametrize("mode", DIVMAT_MODES)
def test_threads_no_windows(self, ts, mode):
self.check(ts, num_threads=5, mode=mode)

@pytest.mark.parametrize("ts", get_example_tree_sequences())
def test_threads_windows(self, ts):
@pytest.mark.parametrize("mode", DIVMAT_MODES)
def test_threads_windows(self, ts, mode):
windows = np.linspace(0, ts.sequence_length, num=11)
self.check(ts, num_threads=5, windows=windows)
self.check(ts, num_threads=5, windows=windows, mode=mode)


class TestSiteDivergence:
Expand All @@ -780,26 +784,29 @@ def test_simulation_example(self):


class TestThreadsNoWindows:
def check(self, ts, num_threads, samples=None):
D1 = ts.divergence_matrix(num_threads=0, samples=samples)
D2 = ts.divergence_matrix(num_threads=num_threads, samples=samples)
def check(self, ts, num_threads, samples=None, mode=None):
D1 = ts.divergence_matrix(num_threads=0, samples=samples, mode=mode)
D2 = ts.divergence_matrix(num_threads=num_threads, samples=samples, mode=mode)
np.testing.assert_array_almost_equal(D1, D2)

@pytest.mark.parametrize("num_threads", [1, 2, 3, 5, 26, 27])
def test_all_trees(self, num_threads):
@pytest.mark.parametrize("mode", DIVMAT_MODES)
def test_all_trees(self, num_threads, mode):
ts = tsutil.all_trees_ts(4)
assert ts.num_trees == 26
self.check(ts, num_threads)
self.check(ts, num_threads, mode=mode)

@pytest.mark.parametrize("samples", [None, [0, 1]])
def test_all_trees_samples(self, samples):
@pytest.mark.parametrize("mode", DIVMAT_MODES)
def test_all_trees_samples(self, samples, mode):
ts = tsutil.all_trees_ts(4)
assert ts.num_trees == 26
self.check(ts, 2, samples)
self.check(ts, 2, samples, mode=mode)

@pytest.mark.parametrize("n", [2, 3, 5, 15])
@pytest.mark.parametrize("num_threads", range(1, 5))
def test_simple_sims(self, n, num_threads):
@pytest.mark.parametrize("mode", DIVMAT_MODES)
def test_simple_sims(self, n, num_threads, mode):
ts = msprime.sim_ancestry(
n,
ploidy=1,
Expand All @@ -809,14 +816,16 @@ def test_simple_sims(self, n, num_threads):
random_seed=1234,
)
assert ts.num_trees >= 2
self.check(ts, num_threads)
self.check(ts, num_threads, mode=mode)


class TestThreadsWindows:
def check(self, ts, num_threads, *, windows, samples=None):
D1 = ts.divergence_matrix(num_threads=0, windows=windows, samples=samples)
def check(self, ts, num_threads, *, windows, samples=None, mode=None):
D1 = ts.divergence_matrix(
num_threads=0, windows=windows, samples=samples, mode=mode
)
D2 = ts.divergence_matrix(
num_threads=num_threads, windows=windows, samples=samples
num_threads=num_threads, windows=windows, samples=samples, mode=mode
)
np.testing.assert_array_almost_equal(D1, D2)

Expand All @@ -832,10 +841,11 @@ def check(self, ts, num_threads, *, windows, samples=None):
([5.1, 5.2, 6.5],),
],
)
def test_all_trees(self, num_threads, windows):
@pytest.mark.parametrize("mode", DIVMAT_MODES)
def test_all_trees(self, num_threads, windows, mode):
ts = tsutil.all_trees_ts(4)
assert ts.num_trees == 26
self.check(ts, num_threads, windows=windows)
self.check(ts, num_threads, windows=windows, mode=mode)

@pytest.mark.parametrize("samples", [None, [0, 1]])
@pytest.mark.parametrize(
Expand All @@ -845,9 +855,10 @@ def test_all_trees(self, num_threads, windows):
(None,),
],
)
def test_all_trees_samples(self, samples, windows):
@pytest.mark.parametrize("mode", DIVMAT_MODES)
def test_all_trees_samples(self, samples, windows, mode):
ts = tsutil.all_trees_ts(4)
self.check(ts, 2, windows=windows, samples=samples)
self.check(ts, 2, windows=windows, samples=samples, mode=mode)

@pytest.mark.parametrize("num_threads", range(1, 5))
@pytest.mark.parametrize(
Expand All @@ -860,7 +871,8 @@ def test_all_trees_samples(self, samples, windows):
(list(range(100)),),
],
)
def test_simple_sims(self, num_threads, windows):
@pytest.mark.parametrize("mode", DIVMAT_MODES)
def test_simple_sims(self, num_threads, windows, mode):
ts = msprime.sim_ancestry(
15,
ploidy=1,
Expand All @@ -870,7 +882,9 @@ def test_simple_sims(self, num_threads, windows):
random_seed=1234,
)
assert ts.num_trees >= 2
self.check(ts, num_threads, windows=windows)
ts = msprime.sim_mutations(ts, rate=0.01, random_seed=1234)
assert ts.num_mutations > 10
self.check(ts, num_threads, windows=windows, mode=mode)


# NOTE these are tests that are for more general functionality that might
Expand Down
4 changes: 4 additions & 0 deletions python/tests/test_lowlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1544,6 +1544,10 @@ def test_divergence_matrix(self):
ts.divergence_matrix(windows=[-1, 0, 1])
with pytest.raises(ValueError):
ts.divergence_matrix(windows=[0, 1], samples="sdf")
with pytest.raises(ValueError, match="Unrecognised stats mode"):
ts.divergence_matrix(windows=[0, 1], mode="sdf")
with pytest.raises(_tskit.LibraryError, match="UNSUPPORTED_STAT_MODE"):
ts.divergence_matrix(windows=[0, 1], mode="node")

def test_load_tables_build_indexes(self):
for ts in self.get_example_tree_sequences():
Expand Down

0 comments on commit 6ba3df3

Please sign in to comment.