Skip to content

Commit

Permalink
Improved C algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Aug 4, 2023
1 parent 042200f commit 73dc0cf
Showing 1 changed file with 89 additions and 52 deletions.
141 changes: 89 additions & 52 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -6597,43 +6597,73 @@ tsk_treeseq_divergence_matrix_branch(const tsk_treeseq_t *self, tsk_size_t num_s
return ret;
}

/* NOTE(review): this region is a diff view rendered without +/- markers.
 * Lines that use parent/time/mutations_per_node/count appear to belong to the
 * removed count_mutations_on_path(); lines that use A/B/D/len_A/len_B appear
 * to belong to the added increment_divergence_matrix_pairs(). Confirm against
 * the repository before treating this text as compilable C. */
static tsk_size_t
count_mutations_on_path(tsk_id_t u, tsk_id_t v, const tsk_id_t *restrict parent,
const double *restrict time, const tsk_size_t *restrict mutations_per_node)
/* increment_divergence_matrix_pairs: for every pair (A[j], B[k]) add one to a
 * single entry of D. D is indexed as a row-major n x n array with
 * n = len_A + len_B. Each pair is ordered so that the smaller value is used
 * as the row, so only entries with row <= column are ever written. */
static void
increment_divergence_matrix_pairs(const tsk_size_t len_A, const tsk_id_t *restrict A,
const tsk_size_t len_B, const tsk_id_t *restrict B, double *restrict D)
{
/* (count_mutations_on_path lines) Walk u and v upwards through parent[],
 * always advancing whichever node has the smaller time, and accumulate
 * mutations_per_node for each node that is stepped over. The walk stops when
 * the two nodes meet or when one of them reaches TSK_NULL. */
double tu, tv;
tsk_size_t count = 0;

tu = time[u];
tv = time[v];
while (u != v) {
if (tu < tv) {
count += mutations_per_node[u];
u = parent[u];
if (u == TSK_NULL) {
break;
}
tu = time[u];
} else {
count += mutations_per_node[v];
v = parent[v];
if (v == TSK_NULL) {
break;
/* (increment_divergence_matrix_pairs lines) u and v hold the current pair,
 * j and k index A and B, and n is the row length used to index D. */
tsk_id_t u, v;
tsk_size_t j, k;
const tsk_id_t n = (tsk_id_t)(len_A + len_B);

for (j = 0; j < len_A; j++) {
for (k = 0; k < len_B; k++) {
u = A[j];
v = B[k];
/* Only increment the upper triangle to (hopefully) improve memory
* access patterns */
if (u > v) {
v = A[j];
u = B[k];
}
tv = time[v];
/* NOTE(review): u * n is evaluated in tsk_id_t. If that is a 32-bit signed
 * type the product could overflow for large n; confirm, or widen the index
 * arithmetic before multiplying. */
D[u * n + v]++;
}
}
/* (count_mutations_on_path lines) If the walk ended without the nodes
 * meeting, count the remaining ancestors of u up to TSK_NULL. */
if (u != v) {
while (u != TSK_NULL) {
count += mutations_per_node[u];
u = parent[u];
}

/* update_site_divergence: record the effect of one mutation above `node`.
 *
 * The subtree below `node` is traversed with an explicit stack using the
 * tree's left_child/right_sib arrays. Every visited node whose entry in
 * sample_index_map is not TSK_NULL has its flag set in descending_bitset.
 * The indexes 0..num_samples-1 are then split into descending_list (flag set)
 * and not_descending_list (flag clear), and both lists are passed to
 * increment_divergence_matrix_pairs() together with D.
 *
 * descending_bitset is used as one int8_t flag per sample index; it is
 * cleared on entry. stack, descending_list and not_descending_list are
 * caller-provided scratch buffers that are overwritten here.
 *
 * NOTE(review): this text comes from a diff view without +/- markers. The
 * lines that reference count, mutations_per_node and parent (and the
 * `return count;` in this void function) appear to be leftovers of the
 * removed count_mutations_on_path(); confirm against the repository. */
static void
update_site_divergence(const tsk_tree_t *tree, tsk_id_t node,
const tsk_id_t *sample_index_map, tsk_size_t num_samples, tsk_id_t *restrict stack,
int8_t *restrict descending_bitset, tsk_id_t *restrict descending_list,
tsk_id_t *restrict not_descending_list, double *D)
{
const tsk_id_t *restrict left_child = tree->left_child;
const tsk_id_t *restrict right_sib = tree->right_sib;
int stack_top;
tsk_id_t a, u, v;
tsk_size_t j, num_descending, num_not_descending;

/* Start from "no sample descends from node". */
tsk_memset(descending_bitset, 0, num_samples * sizeof(*descending_bitset));

/* Explicit-stack traversal rooted at node; stack_top is the index of the
 * top element and -1 means the stack is empty. */
stack_top = 0;
stack[stack_top] = node;
while (stack_top >= 0) {
u = stack[stack_top];
stack_top--;
/* a is the sample index recorded for u, or TSK_NULL if none was recorded. */
a = sample_index_map[u];
if (a != TSK_NULL) {
descending_bitset[a] = 1;
}
while (v != TSK_NULL) {
count += mutations_per_node[v];
v = parent[v];
/* Push every child of u by following the sibling chain. */
for (v = left_child[u]; v != TSK_NULL; v = right_sib[v]) {
stack_top++;
stack[stack_top] = v;
}
}
return count;

/* Partition sample indexes by whether their flag was set above. */
num_descending = 0;
num_not_descending = 0;
for (j = 0; j < num_samples; j++) {
if (descending_bitset[j]) {
descending_list[num_descending] = (tsk_id_t) j;
num_descending++;
} else {
not_descending_list[num_not_descending] = (tsk_id_t) j;
num_not_descending++;
}
}
/* Every index lands in exactly one of the two lists. */
tsk_bug_assert(num_descending + num_not_descending == num_samples);

/* One increment in D for each (descending, not descending) pair. */
increment_divergence_matrix_pairs(
num_descending, descending_list, num_not_descending, not_descending_list, D);
}

/* tsk_treeseq_divergence_matrix_site: for each window, visit the trees that
 * are iterated from the window's left coordinate and call
 * update_site_divergence() once per mutation whose site position lies inside
 * the part of the tree interval that overlaps the window.
 *
 * NOTE(review): the signature and some interior lines are hidden behind diff
 * hunk headers ("Expand All @@ ..."), and the text interleaves removed and
 * added lines without +/- markers. The names samples, windows, num_windows
 * and result are used below but not declared in the visible text; presumably
 * they are parameters. Confirm against the repository. */
static int
Expand All @@ -6644,30 +6674,44 @@ tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_sam
{
int ret = 0;
tsk_tree_t tree;
const tsk_size_t n = num_samples;
const tsk_size_t num_nodes = self->tables->nodes.num_rows;
const double *restrict nodes_time = self->tables->nodes.time;
tsk_size_t i, j, k, tree_site, tree_mut;
tsk_size_t i, j, tree_site, tree_mut;
tsk_site_t site;
tsk_mutation_t mut;
tsk_id_t u, v;
tsk_id_t u;
double left, right, span_left, span_right;
double *restrict D;
/* Scratch buffers are allocated at declaration; the NULL checks happen
 * together further down, after tsk_tree_init(). */
tsk_size_t *mutations_per_node = tsk_malloc(num_nodes * sizeof(*mutations_per_node));
int8_t *descending_bitset = tsk_malloc(num_samples * sizeof(*descending_bitset));
tsk_id_t *descending_list = tsk_malloc(num_samples * sizeof(*descending_list));
tsk_id_t *not_descending_list
= tsk_malloc(num_samples * sizeof(*not_descending_list));
tsk_id_t *sample_index_map = tsk_malloc(num_nodes * sizeof(*sample_index_map));
/* stack starts as NULL so the cleanup at `out` can run even when the
 * function bails out before stack is allocated. */
tsk_id_t *stack = NULL;

ret = tsk_tree_init(&tree, self, 0);
if (ret != 0) {
goto out;
}
if (mutations_per_node == NULL) {

/* The traversal stack is sized from tsk_tree_get_size_bound(&tree). */
stack = tsk_malloc(tsk_tree_get_size_bound(&tree) * sizeof(*stack));
if (descending_bitset == NULL || descending_list == NULL
|| not_descending_list == NULL || sample_index_map == NULL || stack == NULL) {
ret = TSK_ERR_NO_MEMORY;
goto out;
}
/* Build the node -> sample-index map: TSK_NULL for every node first, then
 * the position j in samples[] for each listed sample node. */
for (j = 0; j < num_nodes; j++) {
sample_index_map[j] = TSK_NULL;
}
for (j = 0; j < num_samples; j++) {
u = samples[j];
/* TODO: check for duplicate sample IDs. As written, a repeated ID simply
 * overwrites the earlier index stored for that node.
 * NOTE(review): u is used as an index without a visible bounds check;
 * confirm that samples[] is validated before this point. */
sample_index_map[u] = (tsk_id_t) j;
}

/* Window i spans [windows[i], windows[i + 1]); D points at the
 * num_samples * num_samples block of result that belongs to window i. */
for (i = 0; i < num_windows; i++) {
left = windows[i];
right = windows[i + 1];
D = result + i * n * n;
D = result + i * num_samples * num_samples;
ret = tsk_tree_seek(&tree, left, 0);
if (ret != 0) {
goto out;
Expand All @@ -6676,29 +6720,18 @@ tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_sam
/* Clip the current tree interval to the window. */
span_left = TSK_MAX(tree.interval.left, left);
span_right = TSK_MIN(tree.interval.right, right);

/* NOTE: we could avoid this full memset across all nodes by doing
* the same loops again and decrementing at the end of the main
* tree-loop. It's probably not worth it though, because of the
* overwhelming O(n^2) below */
tsk_memset(mutations_per_node, 0, num_nodes * sizeof(*mutations_per_node));
/* Only sites with span_left <= position < span_right contribute. */
for (tree_site = 0; tree_site < tree.sites_length; tree_site++) {
site = tree.sites[tree_site];
if (span_left <= site.position && site.position < span_right) {
for (tree_mut = 0; tree_mut < site.mutations_length; tree_mut++) {
mut = site.mutations[tree_mut];
mutations_per_node[mut.node]++;
/* Each mutation is applied to D via the subtree below mut.node. */
update_site_divergence(&tree, mut.node, sample_index_map,
num_samples, stack, descending_bitset, descending_list,
not_descending_list, D);
}
}
}

for (j = 0; j < n; j++) {
u = samples[j];
for (k = j + 1; k < n; k++) {
v = samples[k];
D[j * n + k] += (double) count_mutations_on_path(
u, v, tree.parent, nodes_time, mutations_per_node);
}
}
/* Advance to the next tree; a negative return is treated as an error. */
ret = tsk_tree_next(&tree);
if (ret < 0) {
goto out;
Expand All @@ -6708,7 +6741,11 @@ tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_sam
ret = 0;
/* Single exit path: release the tree and every scratch buffer. */
out:
tsk_tree_free(&tree);
tsk_safe_free(mutations_per_node);
tsk_safe_free(sample_index_map);
tsk_safe_free(descending_bitset);
tsk_safe_free(descending_list);
tsk_safe_free(not_descending_list);
tsk_safe_free(stack);
return ret;
}

Expand Down

0 comments on commit 73dc0cf

Please sign in to comment.