Skip to content

Commit

Permalink
Basic working version of site divmat
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Jul 5, 2023
1 parent 65381d9 commit 71034e7
Show file tree
Hide file tree
Showing 4 changed files with 262 additions and 44 deletions.
266 changes: 234 additions & 32 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -6279,17 +6279,16 @@ tsk_treeseq_check_node_bounds(
return ret;
}

int
tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples,
const tsk_id_t *samples_in, tsk_size_t num_windows, const double *windows,
tsk_flags_t TSK_UNUSED(options), double *result)
static int
tsk_treeseq_divergence_matrix_branch(const tsk_treeseq_t *self, tsk_size_t num_samples,
const tsk_id_t *restrict samples, tsk_size_t num_windows,
const double *restrict windows, tsk_flags_t TSK_UNUSED(options),
double *restrict result)
{
int ret = 0;
tsk_tree_t tree;
const tsk_id_t *restrict samples = self->samples;
const double default_windows[] = { 0, self->tables->sequence_length };
const double *restrict nodes_time = self->tables->nodes.time;
tsk_size_t n = self->num_samples;
const tsk_size_t n = num_samples;
tsk_size_t i, j, k;
tsk_id_t u, v, w, u_root, v_root;
double tu, tv, d, span, left, right, span_left, span_right;
Expand All @@ -6306,27 +6305,6 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples,
goto out;
}

if (windows == NULL) {
num_windows = 1;
windows = default_windows;
} else {
ret = tsk_treeseq_check_windows(self, num_windows, windows, 0);
if (ret != 0) {
goto out;
}
}

if (samples_in != NULL) {
samples = samples_in;
n = num_samples;
ret = tsk_treeseq_check_node_bounds(self, n, samples);
if (ret != 0) {
goto out;
}
}

tsk_memset(result, 0, num_windows * n * n * sizeof(*result));

for (i = 0; i < num_windows; i++) {
left = windows[i];
right = windows[i + 1];
Expand Down Expand Up @@ -6365,16 +6343,240 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples,
goto out;
}
}
/* TODO there's probably a better striding pattern that could be used here */
}
ret = 0;
out:
tsk_tree_free(&tree);
sv_tables_free(&sv);
return ret;
}

static tsk_size_t
count_mutations_on_path(tsk_id_t u, tsk_id_t v, const tsk_id_t *restrict parent,
const double *restrict time, const tsk_size_t *restrict mutations_per_node)
{
double tu, tv;
tsk_size_t count = 0;

tu = time[u];
tv = time[v];
while (u != v) {
if (tu < tv) {
count += mutations_per_node[u];
u = parent[u];
if (u == TSK_NULL) {
break;
}
tu = time[u];
} else {
count += mutations_per_node[v];
v = parent[v];
if (v == TSK_NULL) {
break;
}
tv = time[v];
}
}
tsk_bug_assert((u == TSK_NULL) == (v == TSK_NULL));
return count;
}

static int
tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_samples,
const tsk_id_t *restrict samples, tsk_size_t num_windows,
const double *restrict windows, tsk_flags_t TSK_UNUSED(options),
double *restrict result)
{
int ret = 0;
tsk_tree_t tree;
const tsk_size_t n = num_samples;
const tsk_size_t num_nodes = self->tables->nodes.num_rows;
const double *restrict nodes_time = self->tables->nodes.time;
tsk_size_t i, j, k, tree_site, tree_mut;
tsk_site_t site;
tsk_mutation_t mut;
tsk_id_t u, v;
double left, right, span_left, span_right;
double *restrict D;
tsk_size_t *mutations_per_node = malloc(num_nodes * sizeof(*mutations_per_node));

ret = tsk_tree_init(&tree, self, 0);
if (ret != 0) {
goto out;
}
if (mutations_per_node == NULL) {
ret = TSK_ERR_NO_MEMORY;
goto out;
}

for (i = 0; i < num_windows; i++) {
left = windows[i];
right = windows[i + 1];
D = result + i * n * n;
ret = tsk_tree_seek(&tree, left, 0);
if (ret != 0) {
goto out;
}
while (tree.interval.left < right && tree.index != -1) {
span_left = TSK_MAX(tree.interval.left, left);
span_right = TSK_MIN(tree.interval.right, right);

/* NOTE: we could avoid this full memset across all nodes by doing
* the same loops again and decrementing at the end of the main
* tree-loop. It's probably not worth it though, because of the
* overwhelming O(n^2) below */
tsk_memset(mutations_per_node, 0, num_nodes * sizeof(*mutations_per_node));
for (tree_site = 0; tree_site < tree.sites_length; tree_site++) {
site = tree.sites[tree_site];
if (span_left <= site.position && site.position < span_right) {
for (tree_mut = 0; tree_mut < site.mutations_length; tree_mut++) {
mut = site.mutations[tree_mut];
mutations_per_node[mut.node]++;
}
}
}

for (j = 0; j < n; j++) {
u = samples[j];
for (k = j + 1; k < n; k++) {
v = samples[k];
D[j * n + k] += (double) count_mutations_on_path(
u, v, tree.parent, nodes_time, mutations_per_node);
}
}
ret = tsk_tree_next(&tree);
if (ret < 0) {
goto out;
}
}
}
ret = 0;

/* n = len(samples) */
/* D = np.zeros((num_windows, n, n)) */
/* tree = tskit.Tree(ts) */
/* for i in range(num_windows): */
/* left = windows[i] */
/* right = windows[i + 1] */
/* tree.seek(left) */
/* # Iterate over the trees in this window */
/* while tree.interval.left < right and tree.index != -1: */
/* span_left = max(tree.interval.left, left) */
/* span_right = min(tree.interval.right, right) */
/* mutations_per_node = collections.Counter() */
/* for site in tree.sites(): */
/* if span_left <= site.position < span_right: */
/* for mutation in site.mutations: */
/* mutations_per_node[mutation.node] += 1 */
/* for j in range(n): */
/* u = samples[j] */
/* for k in range(j + 1, n): */
/* v = samples[k] */
/* w = tree.mrca(u, v) */
/* if w != tskit.NULL: */
/* wu = w */
/* wv = w */
/* else: */
/* wu = local_root(tree, u) */
/* wv = local_root(tree, v) */
/* du = sum(mutations_per_node[x] for x in rootward_path(tree, u, wu)) */
/* dv = sum(mutations_per_node[x] for x in rootward_path(tree, v, wv)) */
/* # NOTE: we're just accumulating the raw mutation counts, not */
/* # multiplying by span */
/* D[i, j, k] += du + dv */
/* tree.next() */
out:
tsk_tree_free(&tree);
tsk_safe_free(mutations_per_node);
return ret;
}

static void
fill_lower_triangle(
double *restrict result, const tsk_size_t n, const tsk_size_t num_windows)
{
tsk_size_t i, j, k;
double *restrict D;

/* TODO there's probably a better striding pattern that could be used here */
for (i = 0; i < num_windows; i++) {
D = result + i * n * n;
for (j = 0; j < n; j++) {
for (k = j + 1; k < n; k++) {
D[k * n + j] = D[j * n + k];
}
}
}
ret = 0;
}

int
tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples,
const tsk_id_t *samples_in, tsk_size_t num_windows, const double *windows,
tsk_flags_t options, double *result)
{
int ret = 0;
const tsk_id_t *samples = self->samples;
tsk_size_t n = self->num_samples;
const double default_windows[] = { 0, self->tables->sequence_length };
bool stat_site = !!(options & TSK_STAT_SITE);
bool stat_branch = !!(options & TSK_STAT_BRANCH);
bool stat_node = !!(options & TSK_STAT_NODE);

if (stat_node) {
ret = TSK_ERR_UNSUPPORTED_STAT_MODE;
goto out;
}
/* If no mode is specified, we default to site mode */
if (!(stat_site || stat_branch)) {
// FIXME - wrong default!!
stat_branch = true;
}
/* It's an error to specify more than one mode */
if (stat_site + stat_branch > 1) {
ret = TSK_ERR_MULTIPLE_STAT_MODES;
goto out;
}

if ((options & TSK_STAT_POLARISED) || (options & TSK_STAT_SPAN_NORMALISE)) {
/* TODO better error */
ret = TSK_ERR_BAD_PARAM_VALUE;
goto out;
}

if (windows == NULL) {
num_windows = 1;
windows = default_windows;
} else {
ret = tsk_treeseq_check_windows(self, num_windows, windows, 0);
if (ret != 0) {
goto out;
}
}

if (samples_in != NULL) {
samples = samples_in;
n = num_samples;
ret = tsk_treeseq_check_node_bounds(self, n, samples);
if (ret != 0) {
goto out;
}
}

tsk_memset(result, 0, num_windows * n * n * sizeof(*result));

if (stat_branch) {
ret = tsk_treeseq_divergence_matrix_branch(
self, n, samples, num_windows, windows, options, result);
} else {
tsk_bug_assert(stat_site);
ret = tsk_treeseq_divergence_matrix_site(
self, n, samples, num_windows, windows, options, result);
}
if (ret != 0) {
goto out;
}
fill_lower_triangle(result, n, num_windows);

out:
tsk_tree_free(&tree);
sv_tables_free(&sv);
return ret;
}
9 changes: 7 additions & 2 deletions python/_tskitmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -9642,10 +9642,11 @@ static PyObject *
TreeSequence_divergence_matrix(TreeSequence *self, PyObject *args, PyObject *kwds)
{
PyObject *ret = NULL;
static char *kwlist[] = { "windows", "samples", NULL };
static char *kwlist[] = { "windows", "samples", "mode", NULL };
PyArrayObject *result_array = NULL;
PyObject *windows = NULL;
PyObject *py_samples = Py_None;
char *mode = NULL;
PyArrayObject *windows_array = NULL;
PyArrayObject *samples_array = NULL;
tsk_flags_t options = 0;
Expand All @@ -9657,7 +9658,8 @@ TreeSequence_divergence_matrix(TreeSequence *self, PyObject *args, PyObject *kwd
if (TreeSequence_check_state(self) != 0) {
goto out;
}
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O", kwlist, &windows, &py_samples)) {
if (!PyArg_ParseTupleAndKeywords(
args, kwds, "O|Os", kwlist, &windows, &py_samples, &mode)) {
goto out;
}
num_samples = tsk_treeseq_get_num_samples(self->tree_sequence);
Expand All @@ -9681,6 +9683,9 @@ TreeSequence_divergence_matrix(TreeSequence *self, PyObject *args, PyObject *kwd
if (result_array == NULL) {
goto out;
}
if (parse_stats_mode(mode, &options) != 0) {
goto out;
}
// clang-format off
Py_BEGIN_ALLOW_THREADS
err = tsk_treeseq_divergence_matrix(
Expand Down
2 changes: 2 additions & 0 deletions python/tests/test_divmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,8 @@ def test_simulation_example(self):
print(ts.num_mutations)
D1 = site_divergence_matrix_naive(ts)
print(D1)
D2 = ts.divergence_matrix(mode="site")
print(D2)


class TestThreadsNoWindows:
Expand Down
Loading

0 comments on commit 71034e7

Please sign in to comment.