diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 8a3d0afc95..4fdf83b78d 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -6597,43 +6597,73 @@ tsk_treeseq_divergence_matrix_branch(const tsk_treeseq_t *self, tsk_size_t num_s return ret; } -static tsk_size_t -count_mutations_on_path(tsk_id_t u, tsk_id_t v, const tsk_id_t *restrict parent, - const double *restrict time, const tsk_size_t *restrict mutations_per_node) +static void +increment_divergence_matrix_pairs(const tsk_size_t len_A, const tsk_id_t *restrict A, + const tsk_size_t len_B, const tsk_id_t *restrict B, double *restrict D) { - double tu, tv; - tsk_size_t count = 0; - - tu = time[u]; - tv = time[v]; - while (u != v) { - if (tu < tv) { - count += mutations_per_node[u]; - u = parent[u]; - if (u == TSK_NULL) { - break; - } - tu = time[u]; - } else { - count += mutations_per_node[v]; - v = parent[v]; - if (v == TSK_NULL) { - break; + tsk_id_t u, v; + tsk_size_t j, k; + const tsk_id_t n = (tsk_id_t)(len_A + len_B); + + for (j = 0; j < len_A; j++) { + for (k = 0; k < len_B; k++) { + u = A[j]; + v = B[k]; + /* Only increment the upper triangle to (hopefully) improve memory + * access patterns */ + if (u > v) { + v = A[j]; + u = B[k]; } - tv = time[v]; + D[u * n + v]++; } } - if (u != v) { - while (u != TSK_NULL) { - count += mutations_per_node[u]; - u = parent[u]; +} + +static void +update_site_divergence(const tsk_tree_t *tree, tsk_id_t node, + const tsk_id_t *sample_index_map, tsk_size_t num_samples, tsk_id_t *restrict stack, + int8_t *restrict descending_bitset, tsk_id_t *restrict descending_list, + tsk_id_t *restrict not_descending_list, double *D) +{ + const tsk_id_t *restrict left_child = tree->left_child; + const tsk_id_t *restrict right_sib = tree->right_sib; + int stack_top; + tsk_id_t a, u, v; + tsk_size_t j, num_descending, num_not_descending; + + tsk_memset(descending_bitset, 0, num_samples * sizeof(*descending_bitset)); + + stack_top = 0; + stack[stack_top] = node; + while (stack_top >= 0) { + u = stack[stack_top]; + stack_top--; + a = sample_index_map[u]; + if (a != TSK_NULL) { + descending_bitset[a] = 1; } - while (v != TSK_NULL) { - count += mutations_per_node[v]; - v = parent[v]; + for (v = left_child[u]; v != TSK_NULL; v = right_sib[v]) { + stack_top++; + stack[stack_top] = v; } } - return count; + + num_descending = 0; + num_not_descending = 0; + for (j = 0; j < num_samples; j++) { + if (descending_bitset[j]) { + descending_list[num_descending] = (tsk_id_t) j; + num_descending++; + } else { + not_descending_list[num_not_descending] = (tsk_id_t) j; + num_not_descending++; + } + } + tsk_bug_assert(num_descending + num_not_descending == num_samples); + + increment_divergence_matrix_pairs( + num_descending, descending_list, num_not_descending, not_descending_list, D); } static int @@ -6644,30 +6674,44 @@ tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_sam { int ret = 0; tsk_tree_t tree; - const tsk_size_t n = num_samples; const tsk_size_t num_nodes = self->tables->nodes.num_rows; - const double *restrict nodes_time = self->tables->nodes.time; - tsk_size_t i, j, k, tree_site, tree_mut; + tsk_size_t i, j, tree_site, tree_mut; tsk_site_t site; tsk_mutation_t mut; - tsk_id_t u, v; + tsk_id_t u; double left, right, span_left, span_right; double *restrict D; - tsk_size_t *mutations_per_node = tsk_malloc(num_nodes * sizeof(*mutations_per_node)); + int8_t *descending_bitset = tsk_malloc(num_samples * sizeof(*descending_bitset)); + tsk_id_t *descending_list = tsk_malloc(num_samples * sizeof(*descending_list)); + tsk_id_t *not_descending_list + = tsk_malloc(num_samples * sizeof(*not_descending_list)); + tsk_id_t *sample_index_map = tsk_malloc(num_nodes * sizeof(*sample_index_map)); + tsk_id_t *stack = NULL; ret = tsk_tree_init(&tree, self, 0); if (ret != 0) { goto out; } - if (mutations_per_node == NULL) { + + stack = tsk_malloc(tsk_tree_get_size_bound(&tree) * sizeof(*stack)); + if (descending_bitset == NULL || descending_list == NULL + || not_descending_list == NULL || sample_index_map == NULL || stack == NULL) { ret = TSK_ERR_NO_MEMORY; goto out; } + for (j = 0; j < num_nodes; j++) { + sample_index_map[j] = TSK_NULL; + } + for (j = 0; j < num_samples; j++) { + u = samples[j]; + /* TODO CHECK FOR DUPS */ + sample_index_map[u] = (tsk_id_t) j; + } for (i = 0; i < num_windows; i++) { left = windows[i]; right = windows[i + 1]; - D = result + i * n * n; + D = result + i * num_samples * num_samples; ret = tsk_tree_seek(&tree, left, 0); if (ret != 0) { goto out; @@ -6676,29 +6720,18 @@ tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_sam span_left = TSK_MAX(tree.interval.left, left); span_right = TSK_MIN(tree.interval.right, right); - /* NOTE: we could avoid this full memset across all nodes by doing - * the same loops again and decrementing at the end of the main - * tree-loop. It's probably not worth it though, because of the - * overwhelming O(n^2) below */ - tsk_memset(mutations_per_node, 0, num_nodes * sizeof(*mutations_per_node)); for (tree_site = 0; tree_site < tree.sites_length; tree_site++) { site = tree.sites[tree_site]; if (span_left <= site.position && site.position < span_right) { for (tree_mut = 0; tree_mut < site.mutations_length; tree_mut++) { mut = site.mutations[tree_mut]; - mutations_per_node[mut.node]++; + update_site_divergence(&tree, mut.node, sample_index_map, + num_samples, stack, descending_bitset, descending_list, + not_descending_list, D); } } } - for (j = 0; j < n; j++) { - u = samples[j]; - for (k = j + 1; k < n; k++) { - v = samples[k]; - D[j * n + k] += (double) count_mutations_on_path( - u, v, tree.parent, nodes_time, mutations_per_node); - } - } ret = tsk_tree_next(&tree); if (ret < 0) { goto out; @@ -6708,7 +6741,11 @@ tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_sam ret = 0; out: tsk_tree_free(&tree); - tsk_safe_free(mutations_per_node); + tsk_safe_free(sample_index_map); + tsk_safe_free(descending_bitset); + tsk_safe_free(descending_list); + tsk_safe_free(not_descending_list); + tsk_safe_free(stack); return ret; }