From 35fef5cfd741c4917fe839088181aa13526b8e5d Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Wed, 19 Jul 2023 06:24:36 +0100 Subject: [PATCH] Site divmat version with O(n^2) per mutation Closes #2779 --- c/tests/test_stats.c | 7 ++ c/tskit/trees.c | 179 +++++++++++++++++++++++++----------- python/tests/test_divmat.py | 50 ++++------ 3 files changed, 147 insertions(+), 89 deletions(-) diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c index 39f2a063e0..348a04f705 100644 --- a/c/tests/test_stats.c +++ b/c/tests/test_stats.c @@ -2041,6 +2041,13 @@ test_simplest_divergence_matrix(void) ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + sample_ids[0] = 1; + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_DUPLICATE_SAMPLE); + ret = tsk_treeseq_divergence_matrix( + &ts, 2, sample_ids, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_DUPLICATE_SAMPLE); + tsk_treeseq_free(&ts); } diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 8a3d0afc95..6a914582aa 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -6597,69 +6597,104 @@ tsk_treeseq_divergence_matrix_branch(const tsk_treeseq_t *self, tsk_size_t num_s return ret; } -static tsk_size_t -count_mutations_on_path(tsk_id_t u, tsk_id_t v, const tsk_id_t *restrict parent, - const double *restrict time, const tsk_size_t *restrict mutations_per_node) +static void +increment_divergence_matrix_pairs(const tsk_size_t len_A, const tsk_id_t *restrict A, + const tsk_size_t len_B, const tsk_id_t *restrict B, double *restrict D) { - double tu, tv; - tsk_size_t count = 0; - - tu = time[u]; - tv = time[v]; - while (u != v) { - if (tu < tv) { - count += mutations_per_node[u]; - u = parent[u]; - if (u == TSK_NULL) { - break; - } - tu = time[u]; - } else { - count += mutations_per_node[v]; - v = parent[v]; - if (v == TSK_NULL) { - break; + tsk_id_t u, v; + tsk_size_t j, k; + const tsk_id_t n = (tsk_id_t)(len_A + len_B); + + for (j = 0; j < len_A; j++) { + for (k = 0; k < len_B; k++) { + u = A[j]; + v = B[k]; + /* Only increment the upper triangle to (hopefully) improve memory + * access patterns */ + if (u > v) { + v = A[j]; + u = B[k]; } - tv = time[v]; + D[u * n + v]++; } } - if (u != v) { - while (u != TSK_NULL) { - count += mutations_per_node[u]; - u = parent[u]; +} + +static void +update_site_divergence(const tsk_tree_t *tree, tsk_id_t node, + const tsk_id_t *sample_index_map, tsk_size_t num_samples, tsk_id_t *restrict stack, + int8_t *restrict descending_bitset, tsk_id_t *restrict descending_list, + tsk_id_t *restrict not_descending_list, double *D) +{ + const tsk_id_t *restrict left_child = tree->left_child; + const tsk_id_t *restrict right_sib = tree->right_sib; + int stack_top; + tsk_id_t a, u, v; + tsk_size_t j, num_descending, num_not_descending; + + tsk_memset(descending_bitset, 0, num_samples * sizeof(*descending_bitset)); + + stack_top = 0; + stack[stack_top] = node; + while (stack_top >= 0) { + u = stack[stack_top]; + stack_top--; + a = sample_index_map[u]; + if (a != TSK_NULL) { + descending_bitset[a] = 1; } - while (v != TSK_NULL) { - count += mutations_per_node[v]; - v = parent[v]; + for (v = left_child[u]; v != TSK_NULL; v = right_sib[v]) { + stack_top++; + stack[stack_top] = v; + } + } + + num_descending = 0; + num_not_descending = 0; + for (j = 0; j < num_samples; j++) { + if (descending_bitset[j]) { + descending_list[num_descending] = (tsk_id_t) j; + num_descending++; + } else { + not_descending_list[num_not_descending] = (tsk_id_t) j; + num_not_descending++; } } - return count; + tsk_bug_assert(num_descending + num_not_descending == num_samples); + + increment_divergence_matrix_pairs( + num_descending, descending_list, num_not_descending, not_descending_list, D); } static int tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_samples, - const tsk_id_t *restrict samples, tsk_size_t num_windows, - const double *restrict windows, tsk_flags_t TSK_UNUSED(options), + tsk_size_t num_windows, const double *restrict windows, + tsk_flags_t TSK_UNUSED(options), const tsk_id_t *restrict sample_index_map, double *restrict result) { int ret = 0; tsk_tree_t tree; - const tsk_size_t n = num_samples; - const tsk_size_t num_nodes = self->tables->nodes.num_rows; - const double *restrict nodes_time = self->tables->nodes.time; - tsk_size_t i, j, k, tree_site, tree_mut; + tsk_size_t i, tree_site, tree_mut; tsk_site_t site; tsk_mutation_t mut; - tsk_id_t u, v; double left, right, span_left, span_right; double *restrict D; - tsk_size_t *mutations_per_node = tsk_malloc(num_nodes * sizeof(*mutations_per_node)); + const tsk_size_t num_nodes = self->tables->nodes.num_rows; + int8_t *descending_bitset = tsk_malloc(num_samples * sizeof(*descending_bitset)); + tsk_id_t *descending_list = tsk_malloc(num_samples * sizeof(*descending_list)); + tsk_id_t *not_descending_list + = tsk_malloc(num_samples * sizeof(*not_descending_list)); + /* Do *not* use tsk_tree_get_size bound here because it gives a per-tree + * bound, not a global one! */ + tsk_id_t *stack = tsk_malloc(num_nodes * sizeof(*stack)); ret = tsk_tree_init(&tree, self, 0); if (ret != 0) { goto out; } - if (mutations_per_node == NULL) { + + if (descending_bitset == NULL || descending_list == NULL + || not_descending_list == NULL || stack == NULL) { ret = TSK_ERR_NO_MEMORY; goto out; } @@ -6667,7 +6702,7 @@ tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_sam for (i = 0; i < num_windows; i++) { left = windows[i]; right = windows[i + 1]; - D = result + i * n * n; + D = result + i * num_samples * num_samples; ret = tsk_tree_seek(&tree, left, 0); if (ret != 0) { goto out; @@ -6676,29 +6711,18 @@ tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_sam span_left = TSK_MAX(tree.interval.left, left); span_right = TSK_MIN(tree.interval.right, right); - /* NOTE: we could avoid this full memset across all nodes by doing - * the same loops again and decrementing at the end of the main - * tree-loop. It's probably not worth it though, because of the - * overwhelming O(n^2) below */ - tsk_memset(mutations_per_node, 0, num_nodes * sizeof(*mutations_per_node)); for (tree_site = 0; tree_site < tree.sites_length; tree_site++) { site = tree.sites[tree_site]; if (span_left <= site.position && site.position < span_right) { for (tree_mut = 0; tree_mut < site.mutations_length; tree_mut++) { mut = site.mutations[tree_mut]; - mutations_per_node[mut.node]++; + update_site_divergence(&tree, mut.node, sample_index_map, + num_samples, stack, descending_bitset, descending_list, + not_descending_list, D); } } } - for (j = 0; j < n; j++) { - u = samples[j]; - for (k = j + 1; k < n; k++) { - v = samples[k]; - D[j * n + k] += (double) count_mutations_on_path( - u, v, tree.parent, nodes_time, mutations_per_node); - } - } ret = tsk_tree_next(&tree); if (ret < 0) { goto out; @@ -6708,7 +6732,42 @@ tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_sam ret = 0; out: tsk_tree_free(&tree); - tsk_safe_free(mutations_per_node); + tsk_safe_free(descending_bitset); + tsk_safe_free(descending_list); + tsk_safe_free(not_descending_list); + tsk_safe_free(stack); + return ret; +} + +static int +get_sample_index_map(const tsk_size_t num_nodes, const tsk_size_t num_samples, + const tsk_id_t *restrict samples, tsk_id_t **ret_sample_index_map) +{ + int ret = 0; + tsk_size_t j; + tsk_id_t u; + tsk_id_t *sample_index_map = tsk_malloc(num_nodes * sizeof(*sample_index_map)); + + if (sample_index_map == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + /* Assign the output pointer here so that it will be freed in the case + * of an error raised in the input checking */ + *ret_sample_index_map = sample_index_map; + + for (j = 0; j < num_nodes; j++) { + sample_index_map[j] = TSK_NULL; + } + for (j = 0; j < num_samples; j++) { + u = samples[j]; + if (sample_index_map[u] != TSK_NULL) { + ret = TSK_ERR_DUPLICATE_SAMPLE; + goto out; + } + sample_index_map[u] = (tsk_id_t) j; + } +out: return ret; } @@ -6739,9 +6798,11 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples, const tsk_id_t *samples = self->samples; tsk_size_t n = self->num_samples; const double default_windows[] = { 0, self->tables->sequence_length }; + const tsk_size_t num_nodes = self->tables->nodes.num_rows; bool stat_site = !!(options & TSK_STAT_SITE); bool stat_branch = !!(options & TSK_STAT_BRANCH); bool stat_node = !!(options & TSK_STAT_NODE); + tsk_id_t *sample_index_map = NULL; if (stat_node) { ret = TSK_ERR_UNSUPPORTED_STAT_MODE; @@ -6785,6 +6846,11 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples, } } + ret = get_sample_index_map(num_nodes, n, samples, &sample_index_map); + if (ret != 0) { + goto out; + } + tsk_memset(result, 0, num_windows * n * n * sizeof(*result)); if (stat_branch) { @@ -6793,7 +6859,7 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples, } else { tsk_bug_assert(stat_site); ret = tsk_treeseq_divergence_matrix_site( - self, n, samples, num_windows, windows, options, result); + self, n, num_windows, windows, options, sample_index_map, result); } if (ret != 0) { goto out; @@ -6801,5 +6867,6 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples, fill_lower_triangle(result, n, num_windows); out: + tsk_safe_free(sample_index_map); return ret; } diff --git a/python/tests/test_divmat.py b/python/tests/test_divmat.py index acb2403d41..c6df718f83 100644 --- a/python/tests/test_divmat.py +++ b/python/tests/test_divmat.py @@ -22,8 +22,6 @@ """ Test cases for divergence matrix based pairwise stats """ -import collections - import msprime import numpy as np import pytest @@ -279,6 +277,9 @@ def site_divergence_matrix(ts, windows=None, samples=None): samples = ts.samples() if samples is None else samples n = len(samples) + sample_index_map = np.zeros(ts.num_nodes, dtype=int) - 1 + sample_index_map[samples] = np.arange(n) + is_descendant = np.zeros(n, dtype=bool) D = np.zeros((num_windows, n, n)) tree = tskit.Tree(ts) for i in range(num_windows): @@ -289,32 +290,22 @@ def site_divergence_matrix(ts, windows=None, samples=None): while tree.interval.left < right and tree.index != -1: span_left = max(tree.interval.left, left) span_right = min(tree.interval.right, right) - mutations_per_node = collections.Counter() for site in tree.sites(): if span_left <= site.position < span_right: for mutation in site.mutations: - mutations_per_node[mutation.node] += 1 - for j in range(n): - u = samples[j] - for k in range(j + 1, n): - v = samples[k] - w = tree.mrca(u, v) - if w != tskit.NULL: - wu = w - wv = w - else: - wu = local_root(tree, u) - wv = local_root(tree, v) - du = sum(mutations_per_node[x] for x in rootward_path(tree, u, wu)) - dv = sum(mutations_per_node[x] for x in rootward_path(tree, v, wv)) - # NOTE: we're just accumulating the raw mutation counts, not - # multiplying by span - D[i, j, k] += du + dv + descendants = [] + for u in tree.nodes(mutation.node): + if sample_index_map[u] != -1: + is_descendant[sample_index_map[u]] = True + + descendants = np.where(is_descendant)[0] + not_descendants = np.where(np.logical_not(is_descendant))[0] + for j in descendants: + for k in not_descendants: + D[i, j, k] += 1 + D[i, k, j] += 1 + is_descendant[:] = False tree.next() - # Fill out symmetric triangle in the matrix - for j in range(n): - for k in range(j + 1, n): - D[i, k, j] = D[i, j, k] if not windows_specified: D = D[0] return D @@ -511,15 +502,8 @@ def test_single_tree_duplicate_samples(self, mode): # 0 1 ts = tskit.Tree.generate_balanced(4).tree_sequence ts = tsutil.insert_branch_sites(ts) - D1 = check_divmat(ts, samples=[0, 0, 1], compare_stats_api=False, mode=mode) - D2 = np.array( - [ - [0.0, 0.0, 2.0], - [0.0, 0.0, 2.0], - [2.0, 2.0, 0.0], - ] - ) - np.testing.assert_array_equal(D1, D2) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_DUPLICATE_SAMPLE"): + ts.divergence_matrix(samples=[0, 0, 1], mode=mode) @pytest.mark.parametrize("mode", DIVMAT_MODES) def test_single_tree_multiroot(self, mode):