Skip to content

Commit

Permalink
Implement span_normalise for divmat
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Aug 15, 2023
1 parent 7190c2d commit 844083f
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 51 deletions.
35 changes: 25 additions & 10 deletions c/tests/test_stats.c
Original file line number Diff line number Diff line change
Expand Up @@ -263,11 +263,10 @@ verify_mean_descendants(tsk_treeseq_t *ts)
}

/* Check the divergence matrix by running against the stats API equivalent
* code. NOTE: this will not always be equal in site mode, because of a slightly
* different definition wrt to multiple mutations at a site.
* code.
*/
static void
verify_divergence_matrix(tsk_treeseq_t *ts, tsk_flags_t mode)
verify_divergence_matrix(tsk_treeseq_t *ts, tsk_flags_t options)
{
int ret;
const tsk_size_t n = tsk_treeseq_get_num_samples(ts);
Expand All @@ -285,10 +284,10 @@ verify_divergence_matrix(tsk_treeseq_t *ts, tsk_flags_t mode)
}
}
ret = tsk_treeseq_divergence(
ts, n, sample_set_sizes, samples, n * n, index_tuples, 0, NULL, mode, D1);
ts, n, sample_set_sizes, samples, n * n, index_tuples, 0, NULL, options, D1);
CU_ASSERT_EQUAL_FATAL(ret, 0);

ret = tsk_treeseq_divergence_matrix(ts, 0, NULL, 0, NULL, mode, D2);
ret = tsk_treeseq_divergence_matrix(ts, 0, NULL, 0, NULL, options, D2);
CU_ASSERT_EQUAL_FATAL(ret, 0);

for (j = 0; j < n; j++) {
Expand Down Expand Up @@ -1072,7 +1071,9 @@ test_single_tree_divergence_matrix(void)
assert_arrays_almost_equal(16, result, D_site);

verify_divergence_matrix(&ts, TSK_STAT_BRANCH);
verify_divergence_matrix(&ts, TSK_STAT_BRANCH | TSK_STAT_SPAN_NORMALISE);
verify_divergence_matrix(&ts, TSK_STAT_SITE);
verify_divergence_matrix(&ts, TSK_STAT_SITE | TSK_STAT_SPAN_NORMALISE);

tsk_treeseq_free(&ts);
}
Expand Down Expand Up @@ -1120,7 +1121,9 @@ test_single_tree_divergence_matrix_internal_samples(void)
assert_arrays_almost_equal(16, result, D);

verify_divergence_matrix(&ts, TSK_STAT_BRANCH);
verify_divergence_matrix(&ts, TSK_STAT_BRANCH | TSK_STAT_SPAN_NORMALISE);
verify_divergence_matrix(&ts, TSK_STAT_SITE);
verify_divergence_matrix(&ts, TSK_STAT_SITE | TSK_STAT_SPAN_NORMALISE);

tsk_treeseq_free(&ts);
}
Expand Down Expand Up @@ -1165,8 +1168,10 @@ test_single_tree_divergence_matrix_multi_root(void)
CU_ASSERT_EQUAL_FATAL(ret, 0);
assert_arrays_almost_equal(16, result, D_branch);

verify_divergence_matrix(&ts, TSK_STAT_SITE);
verify_divergence_matrix(&ts, TSK_STAT_BRANCH);
verify_divergence_matrix(&ts, TSK_STAT_BRANCH | TSK_STAT_SPAN_NORMALISE);
verify_divergence_matrix(&ts, TSK_STAT_SITE);
verify_divergence_matrix(&ts, TSK_STAT_SITE | TSK_STAT_SPAN_NORMALISE);

tsk_treeseq_free(&ts);
}
Expand Down Expand Up @@ -1839,7 +1844,9 @@ test_paper_ex_divergence_matrix(void)
paper_ex_mutations, paper_ex_individuals, NULL, 0);

verify_divergence_matrix(&ts, TSK_STAT_BRANCH);
verify_divergence_matrix(&ts, TSK_STAT_BRANCH | TSK_STAT_SPAN_NORMALISE);
verify_divergence_matrix(&ts, TSK_STAT_SITE);
verify_divergence_matrix(&ts, TSK_STAT_SITE | TSK_STAT_SPAN_NORMALISE);

tsk_treeseq_free(&ts);
}
Expand Down Expand Up @@ -1999,10 +2006,20 @@ test_simplest_divergence_matrix(void)
CU_ASSERT_EQUAL_FATAL(ret, 0);
assert_arrays_almost_equal(4, D_branch, result);

ret = tsk_treeseq_divergence_matrix(
&ts, 2, sample_ids, 0, NULL, TSK_STAT_BRANCH | TSK_STAT_SPAN_NORMALISE, result);
CU_ASSERT_EQUAL_FATAL(ret, 0);
assert_arrays_almost_equal(4, D_branch, result);

ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result);
CU_ASSERT_EQUAL_FATAL(ret, 0);
assert_arrays_almost_equal(4, D_site, result);

ret = tsk_treeseq_divergence_matrix(
&ts, 2, sample_ids, 0, NULL, TSK_STAT_SPAN_NORMALISE, result);
CU_ASSERT_EQUAL_FATAL(ret, 0);
assert_arrays_almost_equal(4, D_site, result);

ret = tsk_treeseq_divergence_matrix(
&ts, 2, sample_ids, 0, NULL, TSK_STAT_SITE, result);
CU_ASSERT_EQUAL_FATAL(ret, 0);
Expand All @@ -2019,10 +2036,6 @@ test_simplest_divergence_matrix(void)
ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_NODE, result);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSUPPORTED_STAT_MODE);

ret = tsk_treeseq_divergence_matrix(
&ts, 0, NULL, 0, NULL, TSK_STAT_SPAN_NORMALISE, result);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED);

ret = tsk_treeseq_divergence_matrix(
&ts, 0, NULL, 0, NULL, TSK_STAT_POLARISED, result);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_STAT_POLARISED_UNSUPPORTED);
Expand Down Expand Up @@ -2146,7 +2159,9 @@ test_multiroot_divergence_matrix(void)
multiroot_ex_sites, multiroot_ex_mutations, NULL, NULL, 0);

verify_divergence_matrix(&ts, TSK_STAT_BRANCH);
verify_divergence_matrix(&ts, TSK_STAT_BRANCH | TSK_STAT_SPAN_NORMALISE);
verify_divergence_matrix(&ts, TSK_STAT_SITE);
verify_divergence_matrix(&ts, TSK_STAT_SITE | TSK_STAT_SPAN_NORMALISE);

tsk_treeseq_free(&ts);
}
Expand Down
7 changes: 3 additions & 4 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -6809,10 +6809,6 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples,
ret = TSK_ERR_STAT_POLARISED_UNSUPPORTED;
goto out;
}
if (options & TSK_STAT_SPAN_NORMALISE) {
ret = TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED;
goto out;
}

if (windows == NULL) {
num_windows = 1;
Expand Down Expand Up @@ -6855,6 +6851,9 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples,
}
fill_lower_triangle(result, n, num_windows);

if (options & TSK_STAT_SPAN_NORMALISE) {
span_normalise(num_windows, windows, n * n, result);
}
out:
tsk_safe_free(sample_index_map);
return ret;
Expand Down
12 changes: 9 additions & 3 deletions python/_tskitmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -9737,7 +9737,7 @@ static PyObject *
TreeSequence_divergence_matrix(TreeSequence *self, PyObject *args, PyObject *kwds)
{
PyObject *ret = NULL;
static char *kwlist[] = { "windows", "samples", "mode", NULL };
static char *kwlist[] = { "windows", "samples", "mode", "span_normalise", NULL };
PyArrayObject *result_array = NULL;
PyObject *windows = NULL;
PyObject *py_samples = Py_None;
Expand All @@ -9748,13 +9748,14 @@ TreeSequence_divergence_matrix(TreeSequence *self, PyObject *args, PyObject *kwd
npy_intp *shape, dims[3];
tsk_size_t num_samples, num_windows;
tsk_id_t *samples = NULL;
int span_normalise = 0;
int err;

if (TreeSequence_check_state(self) != 0) {
goto out;
}
if (!PyArg_ParseTupleAndKeywords(
args, kwds, "O|Os", kwlist, &windows, &py_samples, &mode)) {
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|Osi", kwlist, &windows, &py_samples,
&mode, &span_normalise)) {
goto out;
}
num_samples = tsk_treeseq_get_num_samples(self->tree_sequence);
Expand All @@ -9778,9 +9779,14 @@ TreeSequence_divergence_matrix(TreeSequence *self, PyObject *args, PyObject *kwd
if (result_array == NULL) {
goto out;
}

if (parse_stats_mode(mode, &options) != 0) {
goto out;
}
if (span_normalise) {
options |= TSK_STAT_SPAN_NORMALISE;
}

// clang-format off
Py_BEGIN_ALLOW_THREADS
err = tsk_treeseq_divergence_matrix(
Expand Down
Loading

0 comments on commit 844083f

Please sign in to comment.