diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c index cceb11d6fd..63b7292322 100644 --- a/c/tests/test_trees.c +++ b/c/tests/test_trees.c @@ -175,6 +175,97 @@ verify_individual_nodes(tsk_treeseq_t *ts) } } +static void +verify_tree_pos(const tsk_treeseq_t *ts, tsk_size_t num_trees, tsk_id_t *tree_parents) +{ + int ret; + const tsk_size_t N = tsk_treeseq_get_num_nodes(ts); + const tsk_id_t *edges_parent = ts->tables->edges.parent; + const tsk_id_t *edges_child = ts->tables->edges.child; + tsk_tree_position_t tree_pos; + tsk_id_t *known_parent; + tsk_id_t *parent = tsk_malloc(N * sizeof(*parent)); + tsk_id_t u, index, j, e; + bool valid; + + CU_ASSERT_FATAL(parent != NULL); + + ret = tsk_tree_position_init(&tree_pos, ts, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + for (u = 0; u < (tsk_id_t) N; u++) { + parent[u] = TSK_NULL; + } + + for (index = 0; index < (tsk_id_t) num_trees; index++) { + known_parent = tree_parents + N * (tsk_size_t) index; + + valid = tsk_tree_position_next(&tree_pos); + CU_ASSERT_TRUE(valid); + CU_ASSERT_EQUAL(index, tree_pos.index); + + for (j = tree_pos.out.start; j < tree_pos.out.stop; j++) { + e = tree_pos.out.order[j]; + parent[edges_child[e]] = TSK_NULL; + } + + for (j = tree_pos.in.start; j < tree_pos.in.stop; j++) { + e = tree_pos.in.order[j]; + parent[edges_child[e]] = edges_parent[e]; + } + + for (u = 0; u < (tsk_id_t) N; u++) { + CU_ASSERT_EQUAL(parent[u], known_parent[u]); + } + } + + valid = tsk_tree_position_next(&tree_pos); + CU_ASSERT_FALSE(valid); + for (j = tree_pos.out.start; j < tree_pos.out.stop; j++) { + e = tree_pos.out.order[j]; + parent[edges_child[e]] = TSK_NULL; + } + for (u = 0; u < (tsk_id_t) N; u++) { + CU_ASSERT_EQUAL(parent[u], TSK_NULL); + } + + for (index = (tsk_id_t) num_trees - 1; index >= 0; index--) { + known_parent = tree_parents + N * (tsk_size_t) index; + + valid = tsk_tree_position_prev(&tree_pos); + CU_ASSERT_TRUE(valid); + CU_ASSERT_EQUAL(index, tree_pos.index); + + for (j = tree_pos.out.start; j > tree_pos.out.stop; j--) { + e = tree_pos.out.order[j]; + parent[edges_child[e]] = TSK_NULL; + } + + for (j = tree_pos.in.start; j > tree_pos.in.stop; j--) { + CU_ASSERT_FATAL(j >= 0); + e = tree_pos.in.order[j]; + parent[edges_child[e]] = edges_parent[e]; + } + + for (u = 0; u < (tsk_id_t) N; u++) { + CU_ASSERT_EQUAL(parent[u], known_parent[u]); + } + } + + valid = tsk_tree_position_prev(&tree_pos); + CU_ASSERT_FALSE(valid); + for (j = tree_pos.out.start; j > tree_pos.out.stop; j--) { + e = tree_pos.out.order[j]; + parent[edges_child[e]] = TSK_NULL; + } + for (u = 0; u < (tsk_id_t) N; u++) { + CU_ASSERT_EQUAL(parent[u], TSK_NULL); + } + + tsk_tree_position_free(&tree_pos); + tsk_safe_free(parent); +} + static void verify_trees(tsk_treeseq_t *ts, tsk_size_t num_trees, tsk_id_t *parents) { @@ -233,6 +324,8 @@ verify_trees(tsk_treeseq_t *ts, tsk_size_t num_trees, tsk_id_t *parents) CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(ts), breakpoints[j]); tsk_tree_free(&tree); + + verify_tree_pos(ts, num_trees, parents); } static tsk_tree_t * @@ -5233,6 +5326,65 @@ test_single_tree_tracked_samples(void) tsk_tree_free(&tree); } +static void +test_single_tree_tree_pos(void) +{ + tsk_treeseq_t ts; + tsk_tree_position_t tree_pos; + bool valid; + int ret; + + tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, NULL, + NULL, NULL, NULL, 0); + + ret = tsk_tree_position_init(&tree_pos, &ts, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + valid = tsk_tree_position_next(&tree_pos); + CU_ASSERT_FATAL(valid); + + CU_ASSERT_EQUAL_FATAL(tree_pos.interval.left, 0); + CU_ASSERT_EQUAL_FATAL(tree_pos.interval.right, 1); + CU_ASSERT_EQUAL_FATAL(tree_pos.in.start, 0); + CU_ASSERT_EQUAL_FATAL(tree_pos.in.stop, 6); + CU_ASSERT_EQUAL_FATAL(tree_pos.in.order, ts.tables->indexes.edge_insertion_order); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.start, 0); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.stop, 0); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.order, ts.tables->indexes.edge_removal_order); + + valid = tsk_tree_position_next(&tree_pos); + CU_ASSERT_FATAL(!valid); + + tsk_tree_position_print_state(&tree_pos, _devnull); + + CU_ASSERT_EQUAL_FATAL(tree_pos.index, -1); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.start, 0); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.stop, 6); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.order, ts.tables->indexes.edge_removal_order); + + valid = tsk_tree_position_prev(&tree_pos); + CU_ASSERT_FATAL(valid); + + CU_ASSERT_EQUAL_FATAL(tree_pos.interval.left, 0); + CU_ASSERT_EQUAL_FATAL(tree_pos.interval.right, 1); + CU_ASSERT_EQUAL_FATAL(tree_pos.in.start, 5); + CU_ASSERT_EQUAL_FATAL(tree_pos.in.stop, -1); + CU_ASSERT_EQUAL_FATAL(tree_pos.in.order, ts.tables->indexes.edge_removal_order); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.start, 5); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.stop, 5); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.order, ts.tables->indexes.edge_insertion_order); + + valid = tsk_tree_position_prev(&tree_pos); + CU_ASSERT_FATAL(!valid); + + CU_ASSERT_EQUAL_FATAL(tree_pos.index, -1); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.start, 5); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.stop, -1); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.order, ts.tables->indexes.edge_insertion_order); + + tsk_tree_position_free(&tree_pos); + tsk_treeseq_free(&ts); +} + /*======================================================= * Multi tree tests. *======================================================*/ @@ -8185,6 +8337,7 @@ main(int argc, char **argv) { "test_single_tree_map_mutations_internal_samples", test_single_tree_map_mutations_internal_samples }, { "test_single_tree_tracked_samples", test_single_tree_tracked_samples }, + { "test_single_tree_tree_pos", test_single_tree_tree_pos }, /* Multi tree tests */ { "test_simple_multi_tree", test_simple_multi_tree }, diff --git a/c/tskit/haplotype_matching.c b/c/tskit/haplotype_matching.c index b942da18d6..d6fdfd7f46 100644 --- a/c/tskit/haplotype_matching.c +++ b/c/tskit/haplotype_matching.c @@ -209,7 +209,7 @@ int tsk_ls_hmm_free(tsk_ls_hmm_t *self) { tsk_tree_free(&self->tree); - tsk_diff_iter_free(&self->diffs); + tsk_tree_position_free(&self->tree_pos); tsk_safe_free(self->recombination_rate); tsk_safe_free(self->mutation_rate); tsk_safe_free(self->recombination_rate); @@ -248,9 +248,8 @@ tsk_ls_hmm_reset(tsk_ls_hmm_t *self) tsk_memset(self->transition_parent, 0xff, self->max_transitions * sizeof(*self->transition_parent)); - /* This is safe because we've already zero'd out the memory. */ - tsk_diff_iter_free(&self->diffs); - ret = tsk_diff_iter_init_from_ts(&self->diffs, self->tree_sequence, false); + tsk_tree_position_free(&self->tree_pos); + ret = tsk_tree_position_init(&self->tree_pos, self->tree_sequence, 0); if (ret != 0) { goto out; } @@ -306,21 +305,20 @@ tsk_ls_hmm_update_tree(tsk_ls_hmm_t *self) int ret = 0; tsk_id_t *restrict parent = self->parent; tsk_id_t *restrict T_index = self->transition_index; + const tsk_id_t *restrict edges_child = self->tree_sequence->tables->edges.child; + const tsk_id_t *restrict edges_parent = self->tree_sequence->tables->edges.parent; tsk_value_transition_t *restrict T = self->transitions; - tsk_edge_list_node_t *record; - tsk_edge_list_t records_out, records_in; - tsk_edge_t edge; - double left, right; - tsk_id_t u; + tsk_id_t u, c, p, j, e; tsk_value_transition_t *vt; - ret = tsk_diff_iter_next(&self->diffs, &left, &right, &records_out, &records_in); - if (ret < 0) { - goto out; - } + tsk_tree_position_next(&self->tree_pos); + tsk_bug_assert(self->tree_pos.index != -1); + tsk_bug_assert(self->tree_pos.index == self->tree.index); - for (record = records_out.head; record != NULL; record = record->next) { - u = record->edge.child; + for (j = self->tree_pos.out.start; j < self->tree_pos.out.stop; j++) { + e = self->tree_pos.out.order[j]; + c = edges_child[e]; + u = c; if (T_index[u] == TSK_NULL) { /* Ensure the subtree we're detaching has a transition at the root */ while (T_index[u] == TSK_NULL) { @@ -328,25 +326,27 @@ tsk_ls_hmm_update_tree(tsk_ls_hmm_t *self) tsk_bug_assert(u != TSK_NULL); } tsk_bug_assert(self->num_transitions < self->max_transitions); - T_index[record->edge.child] = (tsk_id_t) self->num_transitions; - T[self->num_transitions].tree_node = record->edge.child; + T_index[c] = (tsk_id_t) self->num_transitions; + T[self->num_transitions].tree_node = c; T[self->num_transitions].value = T[T_index[u]].value; self->num_transitions++; } - parent[record->edge.child] = TSK_NULL; + parent[c] = TSK_NULL; } - for (record = records_in.head; record != NULL; record = record->next) { - edge = record->edge; - parent[edge.child] = edge.parent; - u = edge.parent; - if (parent[edge.parent] == TSK_NULL) { + for (j = self->tree_pos.in.start; j < self->tree_pos.in.stop; j++) { + e = self->tree_pos.in.order[j]; + c = edges_child[e]; + p = edges_parent[e]; + parent[c] = p; + u = p; + if (parent[p] == TSK_NULL) { /* Grafting onto a new root. */ - if (T_index[record->edge.parent] == TSK_NULL) { - T_index[edge.parent] = (tsk_id_t) self->num_transitions; + if (T_index[p] == TSK_NULL) { + T_index[p] = (tsk_id_t) self->num_transitions; tsk_bug_assert(self->num_transitions < self->max_transitions); - T[self->num_transitions].tree_node = edge.parent; - T[self->num_transitions].value = T[T_index[edge.child]].value; + T[self->num_transitions].tree_node = p; + T[self->num_transitions].value = T[T_index[c]].value; self->num_transitions++; } } else { @@ -356,18 +356,17 @@ tsk_ls_hmm_update_tree(tsk_ls_hmm_t *self) } tsk_bug_assert(u != TSK_NULL); } - tsk_bug_assert(T_index[u] != -1 && T_index[edge.child] != -1); - if (T[T_index[u]].value == T[T_index[edge.child]].value) { - vt = &T[T_index[edge.child]]; + tsk_bug_assert(T_index[u] != -1 && T_index[c] != -1); + if (T[T_index[u]].value == T[T_index[c]].value) { + vt = &T[T_index[c]]; /* Mark the value transition as unusued */ vt->value = -1; vt->tree_node = TSK_NULL; - T_index[edge.child] = TSK_NULL; + T_index[c] = TSK_NULL; } } ret = tsk_ls_hmm_remove_dead_roots(self); -out: return ret; } diff --git a/c/tskit/haplotype_matching.h b/c/tskit/haplotype_matching.h index 46631fb086..e4d82bef03 100644 --- a/c/tskit/haplotype_matching.h +++ b/c/tskit/haplotype_matching.h @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -98,7 +98,10 @@ typedef struct _tsk_ls_hmm_t { tsk_size_t num_nodes; /* state */ tsk_tree_t tree; - tsk_diff_iter_t diffs; + /* NOTE: this tree_position will be redundant once we integrate the top-level + * tree class with this. + */ + tsk_tree_position_t tree_pos; tsk_id_t *parent; /* The probability value transitions on the tree */ tsk_value_transition_t *transitions; diff --git a/c/tskit/trees.c b/c/tskit/trees.c index dac3ac154b..8a3d0afc95 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -3553,6 +3553,164 @@ tsk_treeseq_split_edges(const tsk_treeseq_t *self, double time, tsk_flags_t flag return ret; } +/* ======================================================== * + * tree_position + * ======================================================== */ + +static void +tsk_tree_position_set_null(tsk_tree_position_t *self) +{ + self->index = -1; + self->interval.left = 0; + self->interval.right = 0; +} + +int +tsk_tree_position_init(tsk_tree_position_t *self, const tsk_treeseq_t *tree_sequence, + tsk_flags_t TSK_UNUSED(options)) +{ + memset(self, 0, sizeof(*self)); + self->tree_sequence = tree_sequence; + tsk_tree_position_set_null(self); + return 0; +} + +int +tsk_tree_position_free(tsk_tree_position_t *TSK_UNUSED(self)) +{ + return 0; +} + +int +tsk_tree_position_print_state(const tsk_tree_position_t *self, FILE *out) +{ + fprintf(out, "Tree position state\n"); + fprintf(out, "index = %d\n", (int) self->index); + fprintf( + out, "out = start=%d\tstop=%d\n", (int) self->out.start, (int) self->out.stop); + fprintf( + out, "in = start=%d\tstop=%d\n", (int) self->in.start, (int) self->in.stop); + return 0; +} + +bool +tsk_tree_position_next(tsk_tree_position_t *self) +{ + const tsk_table_collection_t *tables = self->tree_sequence->tables; + const tsk_id_t M = (tsk_id_t) tables->edges.num_rows; + const tsk_id_t num_trees = (tsk_id_t) self->tree_sequence->num_trees; + const double *restrict left_coords = tables->edges.left; + const tsk_id_t *restrict left_order = tables->indexes.edge_insertion_order; + const double *restrict right_coords = tables->edges.right; + const tsk_id_t *restrict right_order = tables->indexes.edge_removal_order; + const double *restrict breakpoints = self->tree_sequence->breakpoints; + tsk_id_t j, left_current_index, right_current_index; + double left; + + if (self->index == -1) { + self->interval.right = 0; + self->in.stop = 0; + self->out.stop = 0; + self->direction = TSK_DIR_FORWARD; + } + + if (self->direction == TSK_DIR_FORWARD) { + left_current_index = self->in.stop; + right_current_index = self->out.stop; + } else { + left_current_index = self->out.stop + 1; + right_current_index = self->in.stop + 1; + } + + left = self->interval.right; + + j = right_current_index; + self->out.start = j; + while (j < M && right_coords[right_order[j]] == left) { + j++; + } + self->out.stop = j; + self->out.order = right_order; + + j = left_current_index; + self->in.start = j; + while (j < M && left_coords[left_order[j]] == left) { + j++; + } + self->in.stop = j; + self->in.order = left_order; + + self->direction = TSK_DIR_FORWARD; + self->index++; + if (self->index == num_trees) { + tsk_tree_position_set_null(self); + } else { + self->interval.left = left; + self->interval.right = breakpoints[self->index + 1]; + } + return self->index != -1; +} + +bool +tsk_tree_position_prev(tsk_tree_position_t *self) +{ + const tsk_table_collection_t *tables = self->tree_sequence->tables; + const tsk_id_t M = (tsk_id_t) tables->edges.num_rows; + const double sequence_length = tables->sequence_length; + const tsk_id_t num_trees = (tsk_id_t) self->tree_sequence->num_trees; + const double *restrict left_coords = tables->edges.left; + const tsk_id_t *restrict left_order = tables->indexes.edge_insertion_order; + const double *restrict right_coords = tables->edges.right; + const tsk_id_t *restrict right_order = tables->indexes.edge_removal_order; + const double *restrict breakpoints = self->tree_sequence->breakpoints; + tsk_id_t j, left_current_index, right_current_index; + double right; + + if (self->index == -1) { + self->index = num_trees; + self->interval.left = sequence_length; + self->in.stop = M - 1; + self->out.stop = M - 1; + self->direction = TSK_DIR_REVERSE; + } + + if (self->direction == TSK_DIR_REVERSE) { + left_current_index = self->out.stop; + right_current_index = self->in.stop; + } else { + left_current_index = self->in.stop - 1; + right_current_index = self->out.stop - 1; + } + + right = self->interval.left; + + j = left_current_index; + self->out.start = j; + while (j >= 0 && left_coords[left_order[j]] == right) { + j--; + } + self->out.stop = j; + self->out.order = left_order; + + j = right_current_index; + self->in.start = j; + while (j >= 0 && right_coords[right_order[j]] == right) { + j--; + } + self->in.stop = j; + self->in.order = right_order; + + self->index--; + self->direction = TSK_DIR_REVERSE; + if (self->index == -1) { + tsk_tree_position_set_null(self); + } else { + self->interval.left = breakpoints[self->index]; + self->interval.right = right; + } + return self->index != -1; +} + /* ======================================================== * * Tree * ======================================================== */ @@ -5946,25 +6104,29 @@ update_kc_subtree_state( } static int -update_kc_incremental(tsk_tree_t *self, kc_vectors *kc, tsk_edge_list_t *edges_out, - tsk_edge_list_t *edges_in, tsk_size_t *depths) +update_kc_incremental( + tsk_tree_t *tree, kc_vectors *kc, tsk_tree_position_t *tree_pos, tsk_size_t *depths) { int ret = 0; - tsk_edge_list_node_t *record; - tsk_edge_t *e; - tsk_id_t u; + tsk_id_t u, v, e, j; double root_time, time; - const double *times = self->tree_sequence->tables->nodes.time; + const double *restrict times = tree->tree_sequence->tables->nodes.time; + const tsk_id_t *restrict edges_child = tree->tree_sequence->tables->edges.child; + const tsk_id_t *restrict edges_parent = tree->tree_sequence->tables->edges.parent; + + tsk_bug_assert(tree_pos->index == tree->index); + tsk_bug_assert(tree_pos->interval.left == tree->interval.left); + tsk_bug_assert(tree_pos->interval.right == tree->interval.right); /* Update state of detached subtrees */ - for (record = edges_out->tail; record != NULL; record = record->prev) { - e = &record->edge; - u = e->child; + for (j = tree_pos->out.stop - 1; j >= tree_pos->out.start; j--) { + e = tree_pos->out.order[j]; + u = edges_child[e]; depths[u] = 0; - if (self->parent[u] == TSK_NULL) { - root_time = times[tsk_tree_node_root(self, u)]; - ret = update_kc_subtree_state(self, kc, u, depths, root_time); + if (tree->parent[u] == TSK_NULL) { + root_time = times[tsk_tree_node_root(tree, u)]; + ret = update_kc_subtree_state(tree, kc, u, depths, root_time); if (ret != 0) { goto out; } @@ -5972,25 +6134,25 @@ update_kc_incremental(tsk_tree_t *self, kc_vectors *kc, tsk_edge_list_t *edges_o } /* Propagate state change down into reattached subtrees. */ - for (record = edges_in->tail; record != NULL; record = record->prev) { - e = &record->edge; - u = e->child; + for (j = tree_pos->in.stop - 1; j >= tree_pos->in.start; j--) { + e = tree_pos->in.order[j]; + u = edges_child[e]; + v = edges_parent[e]; - tsk_bug_assert(depths[e->child] == 0); - depths[u] = depths[e->parent] + 1; + tsk_bug_assert(depths[u] == 0); + depths[u] = depths[v] + 1; - root_time = times[tsk_tree_node_root(self, u)]; - ret = update_kc_subtree_state(self, kc, u, depths, root_time); + root_time = times[tsk_tree_node_root(tree, u)]; + ret = update_kc_subtree_state(tree, kc, u, depths, root_time); if (ret != 0) { goto out; } - if (tsk_tree_is_sample(self, u)) { - time = tsk_tree_get_branch_length_unsafe(self, u); - update_kc_vectors_single_sample(self->tree_sequence, kc, u, time); + if (tsk_tree_is_sample(tree, u)) { + time = tsk_tree_get_branch_length_unsafe(tree, u); + update_kc_vectors_single_sample(tree->tree_sequence, kc, u, time); } } - out: return ret; } @@ -6006,19 +6168,18 @@ tsk_treeseq_kc_distance(const tsk_treeseq_t *self, const tsk_treeseq_t *other, const tsk_treeseq_t *treeseqs[2] = { self, other }; tsk_tree_t trees[2]; kc_vectors kcs[2]; - tsk_diff_iter_t diff_iters[2]; - tsk_edge_list_t edges_out[2]; - tsk_edge_list_t edges_in[2]; + /* TODO the tree_pos here is redundant because we should be using this interally + * in the trees to do the advancing. Once we have converted the tree over to using + * tree_pos internally, we can get rid of these tree_pos variables and use + * the values stored in the trees themselves */ + tsk_tree_position_t tree_pos[2]; tsk_size_t *depths[2]; - double t0_left, t0_right, t1_left, t1_right; int ret = 0; for (i = 0; i < 2; i++) { tsk_memset(&trees[i], 0, sizeof(trees[i])); - tsk_memset(&diff_iters[i], 0, sizeof(diff_iters[i])); + tsk_memset(&tree_pos[i], 0, sizeof(tree_pos[i])); tsk_memset(&kcs[i], 0, sizeof(kcs[i])); - tsk_memset(&edges_out[i], 0, sizeof(edges_out[i])); - tsk_memset(&edges_in[i], 0, sizeof(edges_in[i])); depths[i] = NULL; } @@ -6033,7 +6194,7 @@ tsk_treeseq_kc_distance(const tsk_treeseq_t *self, const tsk_treeseq_t *other, if (ret != 0) { goto out; } - ret = tsk_diff_iter_init_from_ts(&diff_iters[i], treeseqs[i], false); + ret = tsk_tree_position_init(&tree_pos[i], treeseqs[i], 0); if (ret != 0) { goto out; } @@ -6060,11 +6221,10 @@ tsk_treeseq_kc_distance(const tsk_treeseq_t *self, const tsk_treeseq_t *other, if (ret != 0) { goto out; } - ret = tsk_diff_iter_next( - &diff_iters[0], &t0_left, &t0_right, &edges_out[0], &edges_in[0]); - tsk_bug_assert(ret == TSK_TREE_OK); - ret = update_kc_incremental( - &trees[0], &kcs[0], &edges_out[0], &edges_in[0], depths[0]); + tsk_tree_position_next(&tree_pos[0]); + tsk_bug_assert(tree_pos[0].index == 0); + + ret = update_kc_incremental(&trees[0], &kcs[0], &tree_pos[0], depths[0]); if (ret != 0) { goto out; } @@ -6073,37 +6233,37 @@ tsk_treeseq_kc_distance(const tsk_treeseq_t *self, const tsk_treeseq_t *other, if (ret != 0) { goto out; } - ret = tsk_diff_iter_next( - &diff_iters[1], &t1_left, &t1_right, &edges_out[1], &edges_in[1]); - tsk_bug_assert(ret == TSK_TREE_OK); + tsk_tree_position_next(&tree_pos[1]); + tsk_bug_assert(tree_pos[1].index != -1); - ret = update_kc_incremental( - &trees[1], &kcs[1], &edges_out[1], &edges_in[1], depths[1]); + ret = update_kc_incremental(&trees[1], &kcs[1], &tree_pos[1], depths[1]); if (ret != 0) { goto out; } - while (t0_right < t1_right) { - span = t0_right - left; + tsk_bug_assert(trees[0].interval.left == tree_pos[0].interval.left); + tsk_bug_assert(trees[0].interval.right == tree_pos[0].interval.right); + tsk_bug_assert(trees[1].interval.left == tree_pos[1].interval.left); + tsk_bug_assert(trees[1].interval.right == tree_pos[1].interval.right); + while (trees[0].interval.right < trees[1].interval.right) { + span = trees[0].interval.right - left; total += norm_kc_vectors(&kcs[0], &kcs[1], lambda_) * span; - left = t0_right; + left = trees[0].interval.right; ret = tsk_tree_next(&trees[0]); tsk_bug_assert(ret == TSK_TREE_OK); ret = check_kc_distance_tree_inputs(&trees[0]); if (ret != 0) { goto out; } - ret = tsk_diff_iter_next( - &diff_iters[0], &t0_left, &t0_right, &edges_out[0], &edges_in[0]); - tsk_bug_assert(ret == TSK_TREE_OK); - ret = update_kc_incremental( - &trees[0], &kcs[0], &edges_out[0], &edges_in[0], depths[0]); + tsk_tree_position_next(&tree_pos[0]); + tsk_bug_assert(tree_pos[0].index != -1); + ret = update_kc_incremental(&trees[0], &kcs[0], &tree_pos[0], depths[0]); if (ret != 0) { goto out; } } - span = t1_right - left; - left = t1_right; + span = trees[1].interval.right - left; + left = trees[1].interval.right; total += norm_kc_vectors(&kcs[0], &kcs[1], lambda_) * span; } if (ret != 0) { @@ -6114,7 +6274,7 @@ tsk_treeseq_kc_distance(const tsk_treeseq_t *self, const tsk_treeseq_t *other, out: for (i = 0; i < 2; i++) { tsk_tree_free(&trees[i]); - tsk_diff_iter_free(&diff_iters[i]); + tsk_tree_position_free(&tree_pos[i]); kc_vectors_free(&kcs[i]); tsk_safe_free(depths[i]); } diff --git a/c/tskit/trees.h b/c/tskit/trees.h index b36a38c31f..95c66a6ac7 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1739,6 +1739,40 @@ bool tsk_tree_equals(const tsk_tree_t *self, const tsk_tree_t *other); int tsk_diff_iter_init_from_ts( tsk_diff_iter_t *self, const tsk_treeseq_t *tree_sequence, tsk_flags_t options); +/* Temporarily putting this here to avoid problems with doxygen. Will need to + * move up the file later when it gets incorporated into the tsk_tree_t object. + */ +typedef struct { + tsk_id_t index; + struct { + double left; + double right; + } interval; + struct { + tsk_id_t start; + tsk_id_t stop; + const tsk_id_t *order; + } in; + struct { + tsk_id_t start; + tsk_id_t stop; + const tsk_id_t *order; + } out; + tsk_id_t left_current_index; + tsk_id_t right_current_index; + int direction; + const tsk_treeseq_t *tree_sequence; +} tsk_tree_position_t; + +int tsk_tree_position_init( + tsk_tree_position_t *self, const tsk_treeseq_t *tree_sequence, tsk_flags_t options); +int tsk_tree_position_free(tsk_tree_position_t *self); +int tsk_tree_position_print_state(const tsk_tree_position_t *self, FILE *out); +bool tsk_tree_position_next(tsk_tree_position_t *self); +bool tsk_tree_position_prev(tsk_tree_position_t *self); +int tsk_tree_position_seek_forward(tsk_tree_position_t *self, tsk_id_t index); +int tsk_tree_position_seek_backward(tsk_tree_position_t *self, tsk_id_t index); + #ifdef __cplusplus } #endif diff --git a/python/tests/test_tree_positioning.py b/python/tests/test_tree_positioning.py new file mode 100644 index 0000000000..961f0810f7 --- /dev/null +++ b/python/tests/test_tree_positioning.py @@ -0,0 +1,470 @@ +# MIT License +# +# Copyright (c) 2023 Tskit Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Tests for tree iterator schemes. Mostly used to develop the incremental +iterator infrastructure. +""" +import msprime +import numpy as np +import pytest + +import tests +import tskit +from tests import tsutil +from tests.test_highlevel import get_example_tree_sequences + +# ↑ See https://github.com/tskit-dev/tskit/issues/1804 for when +# we can remove this. + + +class StatefulTree: + """ + Just enough functionality to mimic the low-level tree implementation + for testing of forward/backward moving. + """ + + def __init__(self, ts): + self.ts = ts + self.tree_pos = tsutil.TreePosition(ts) + self.parent = [-1 for _ in range(ts.num_nodes)] + + def __str__(self): + s = f"parent: {self.parent}\nposition:\n" + for line in str(self.tree_pos).splitlines(): + s += f"\t{line}\n" + return s + + def assert_equal(self, other): + assert self.parent == other.parent + assert self.tree_pos.index == other.tree_pos.index + assert self.tree_pos.interval == other.tree_pos.interval + + def next(self): # NOQA: A003 + valid = self.tree_pos.next() + if valid: + for j in range(self.tree_pos.out_range.start, self.tree_pos.out_range.stop): + e = self.tree_pos.out_range.order[j] + c = self.ts.edges_child[e] + self.parent[c] = -1 + for j in range(self.tree_pos.in_range.start, self.tree_pos.in_range.stop): + e = self.tree_pos.in_range.order[j] + c = self.ts.edges_child[e] + p = self.ts.edges_parent[e] + self.parent[c] = p + return valid + + def prev(self): + valid = self.tree_pos.prev() + if valid: + for j in range( + self.tree_pos.out_range.start, self.tree_pos.out_range.stop, -1 + ): + e = self.tree_pos.out_range.order[j] + c = self.ts.edges_child[e] + self.parent[c] = -1 + for j in range( + self.tree_pos.in_range.start, self.tree_pos.in_range.stop, -1 + ): + e = self.tree_pos.in_range.order[j] + c = self.ts.edges_child[e] + p = self.ts.edges_parent[e] + self.parent[c] = p + return valid + + def iter_forward(self, index): + while self.tree_pos.index != index: + self.next() + + def seek_forward(self, index): + old_left, old_right = self.tree_pos.interval + self.tree_pos.seek_forward(index) + left, right = self.tree_pos.interval + # print() + # print("Current interval:", old_left, old_right) + # print("New interval:", left, right) + # print("index:", index, "out_range:", self.tree_pos.out_range) + for j in range(self.tree_pos.out_range.start, self.tree_pos.out_range.stop): + e = self.tree_pos.out_range.order[j] + e_left = self.ts.edges_left[e] + # We only need to remove an edge if it's in the current tree, which + # can only happen if the edge's left coord is < the current tree's + # right coordinate. + if e_left < old_right: + c = self.ts.edges_child[e] + assert self.parent[c] != -1 + self.parent[c] = -1 + assert e_left < left + # print("index:", index, "in_range:", self.tree_pos.in_range) + for j in range(self.tree_pos.in_range.start, self.tree_pos.in_range.stop): + e = self.tree_pos.in_range.order[j] + if self.ts.edges_left[e] <= left < self.ts.edges_right[e]: + # print("keep", j, e, self.ts.edges_left[e], self.ts.edges_right[e]) + # print( + # "INSERT:", + # self.ts.edge(e), + # self.ts.nodes_time[self.ts.edges_parent[e]], + # ) + c = self.ts.edges_child[e] + p = self.ts.edges_parent[e] + self.parent[c] = p + else: + a = self.tree_pos.in_range.start + b = self.tree_pos.in_range.stop + # The first and last indexes in the range should always be valid + # for the tree. + assert a < j < b - 1 + # print("skip", j, e, self.ts.edges_left[e], self.ts.edges_right[e]) + + def seek_backward(self, index): + # TODO + while self.tree_pos.index != index: + self.prev() + + def iter_backward(self, index): + while self.tree_pos.index != index: + self.prev() + + +def check_iters_forward(ts): + alg_t_output = tsutil.algorithm_T(ts) + lib_tree = tskit.Tree(ts) + tree_pos = tsutil.TreePosition(ts) + sample_count = np.zeros(ts.num_nodes, dtype=int) + sample_count[ts.samples()] = 1 + parent1 = [-1 for _ in range(ts.num_nodes)] + i = 0 + lib_tree.next() + while tree_pos.next(): + out_times = [] + for j in range(tree_pos.out_range.start, tree_pos.out_range.stop): + e = tree_pos.out_range.order[j] + c = ts.edges_child[e] + p = ts.edges_parent[e] + out_times.append(ts.nodes_time[p]) + parent1[c] = -1 + in_times = [] + for j in range(tree_pos.in_range.start, tree_pos.in_range.stop): + e = tree_pos.in_range.order[j] + c = ts.edges_child[e] + p = ts.edges_parent[e] + in_times.append(ts.nodes_time[p]) + parent1[c] = p + # We must visit the edges in *increasing* time order on the way in, + # and *decreasing* order on the way out. Otherwise we get quadratic + # behaviour for algorithms that need to propagate changes up to the + # root. + assert out_times == sorted(out_times, reverse=True) + assert in_times == sorted(in_times) + + interval, parent2 = next(alg_t_output) + assert list(interval) == list(tree_pos.interval) + assert parent1 == parent2 + + assert lib_tree.index == i + assert list(lib_tree.interval) == list(interval) + assert list(lib_tree.parent_array[:-1]) == parent1 + + lib_tree.next() + i += 1 + assert i == ts.num_trees + assert lib_tree.index == -1 + assert next(alg_t_output, None) is None + + +def check_iters_back(ts): + alg_t_output = [ + (list(interval), list(parent)) for interval, parent in tsutil.algorithm_T(ts) + ] + i = len(alg_t_output) - 1 + + lib_tree = tskit.Tree(ts) + tree_pos = tsutil.TreePosition(ts) + parent1 = [-1 for _ in range(ts.num_nodes)] + + lib_tree.last() + + while tree_pos.prev(): + # print(tree_pos.out_range) + out_times = [] + for j in range(tree_pos.out_range.start, tree_pos.out_range.stop, -1): + e = tree_pos.out_range.order[j] + c = ts.edges_child[e] + p = ts.edges_parent[e] + out_times.append(ts.nodes_time[p]) + parent1[c] = -1 + in_times = [] + for j in range(tree_pos.in_range.start, tree_pos.in_range.stop, -1): + e = tree_pos.in_range.order[j] + c = ts.edges_child[e] + p = ts.edges_parent[e] + in_times.append(ts.nodes_time[p]) + parent1[c] = p + + # We must visit the edges in *increasing* time order on the way in, + # and *decreasing* order on the way out. Otherwise we get quadratic + # behaviour for algorithms that need to propagate changes up to the + # root. + assert out_times == sorted(out_times, reverse=True) + assert in_times == sorted(in_times) + + interval, parent2 = alg_t_output[i] + assert list(interval) == list(tree_pos.interval) + assert parent1 == parent2 + + assert lib_tree.index == i + assert list(lib_tree.interval) == list(interval) + assert list(lib_tree.parent_array[:-1]) == parent1 + + lib_tree.prev() + i -= 1 + + assert lib_tree.index == -1 + assert i == -1 + + +def check_forward_back_sweep(ts): + alg_t_output = [ + (list(interval), list(parent)) for interval, parent in tsutil.algorithm_T(ts) + ] + for j in range(ts.num_trees - 1): + tree = StatefulTree(ts) + # Seek forward to j + k = 0 + while k <= j: + tree.next() + interval, parent = alg_t_output[k] + assert tree.tree_pos.index == k + assert list(tree.tree_pos.interval) == interval + assert parent == tree.parent + k += 1 + k = j + # And back to zero + while k >= 0: + interval, parent = alg_t_output[k] + assert tree.tree_pos.index == k + assert list(tree.tree_pos.interval) == interval + assert parent == tree.parent + tree.prev() + k -= 1 + + +class TestDirectionSwitching: + # 2.00┊ ┊ 4 ┊ 4 ┊ 4 ┊ + # ┊ ┊ ┏━┻┓ ┊ ┏┻━┓ ┊ ┏┻━┓ ┊ + # 1.00┊ 3 ┊ ┃ 3 ┊ 3 ┃ ┊ 3 ┃ ┊ + # ┊ ┏━╋━┓ ┊ ┃ ┏┻┓ ┊ ┏┻┓ ┃ ┊ ┏┻┓ ┃ ┊ + # 0.00┊ 0 1 2 ┊ 0 1 2 ┊ 0 2 1 ┊ 0 1 2 ┊ + # 0 1 2 3 4 + # index 0 1 2 3 + def ts(self): + return tsutil.all_trees_ts(3) + + @pytest.mark.parametrize("index", [1, 2, 3]) + def test_forward_to_prev(self, index): + tree1 = StatefulTree(self.ts()) + tree1.iter_forward(index) + tree1.prev() + tree2 = StatefulTree(self.ts()) + tree2.iter_forward(index - 1) + tree1.assert_equal(tree2) + tree2 = StatefulTree(self.ts()) + tree2.iter_backward(index - 1) + tree1.assert_equal(tree2) + + @pytest.mark.parametrize("index", [1, 2, 3]) + def test_seek_forward_from_prev(self, index): + tree1 = StatefulTree(self.ts()) + tree1.iter_forward(index) + tree1.prev() + tree1.seek_forward(index) + tree2 = StatefulTree(self.ts()) + tree2.iter_forward(index) + tree1.assert_equal(tree2) + + @pytest.mark.parametrize("index", [0, 1, 2]) + def test_backward_to_next(self, index): + tree1 = StatefulTree(self.ts()) + tree1.iter_backward(index) + tree1.next() + tree2 = StatefulTree(self.ts()) + tree2.iter_backward(index + 1) + tree1.assert_equal(tree2) + tree2 = StatefulTree(self.ts()) + tree2.iter_forward(index + 1) + tree1.assert_equal(tree2) + + @pytest.mark.parametrize("index", [1, 2, 3]) + def test_forward_next_prev(self, index): + tree1 = StatefulTree(self.ts()) + tree1.iter_forward(index) + tree1.prev() + tree2 = StatefulTree(self.ts()) + tree2.iter_forward(index - 1) + tree1.assert_equal(tree2) + tree2 = StatefulTree(self.ts()) + tree2.iter_backward(index - 1) + tree1.assert_equal(tree2) + + @pytest.mark.parametrize("index", [1, 2, 3]) + def test_seek_forward_next_prev(self, index): + tree1 = StatefulTree(self.ts()) + tree1.iter_forward(index) + tree1.prev() + tree2 = StatefulTree(self.ts()) + tree2.seek_forward(index - 1) + tree1.assert_equal(tree2) + tree2 = StatefulTree(self.ts()) + tree2.iter_backward(index - 1) + tree1.assert_equal(tree2) + + @pytest.mark.parametrize("index", [1, 2, 3]) + def test_seek_forward_from_null(self, index): + tree1 = StatefulTree(self.ts()) + tree1.seek_forward(index) + tree2 = StatefulTree(self.ts()) + tree2.iter_forward(index) + tree1.assert_equal(tree2) + + def test_seek_forward_next_null(self): + tree1 = StatefulTree(self.ts()) + tree1.seek_forward(3) + tree1.next() + assert tree1.tree_pos.index == -1 + assert list(tree1.tree_pos.interval) == [0, 0] + + +class TestSeeking: + @tests.cached_example + def ts(self): + ts = tsutil.all_trees_ts(4) + assert ts.num_trees == 26 + return ts + + @pytest.mark.parametrize("index", range(26)) + def test_seek_forward_from_null(self, index): + tree1 = StatefulTree(self.ts()) + tree1.seek_forward(index) + tree2 = StatefulTree(self.ts()) + tree2.iter_forward(index) + tree1.assert_equal(tree2) + + @pytest.mark.parametrize("index", range(1, 26)) + def test_seek_forward_from_first(self, index): + tree1 = StatefulTree(self.ts()) + tree1.next() + tree1.seek_forward(index) + tree2 = StatefulTree(self.ts()) + tree2.iter_forward(index) + tree1.assert_equal(tree2) + + @pytest.mark.parametrize("index", range(1, 26)) + def test_seek_last_from_index(self, index): + ts = self.ts() + tree1 = StatefulTree(ts) + tree1.iter_forward(index) + tree1.seek_forward(ts.num_trees - 1) + tree2 = StatefulTree(ts) + tree2.prev() + tree1.assert_equal(tree2) + + +class TestAllTreesTs: + @pytest.mark.parametrize("n", [2, 3, 4]) + def test_forward_full(self, n): + ts = tsutil.all_trees_ts(n) + check_iters_forward(ts) + + @pytest.mark.parametrize("n", [2, 3, 4]) + def test_back_full(self, n): + ts = tsutil.all_trees_ts(n) + check_iters_back(ts) + + @pytest.mark.parametrize("n", [2, 3, 4]) + def test_forward_back(self, n): + ts = tsutil.all_trees_ts(n) + check_forward_back_sweep(ts) + + +class TestManyTreesSimulationExample: + @tests.cached_example + def ts(self): + ts = msprime.sim_ancestry( + 10, sequence_length=1000, recombination_rate=0.1, random_seed=1234 + ) + assert ts.num_trees > 250 + return ts + + @pytest.mark.parametrize("index", [1, 5, 10, 50, 100]) + def test_seek_forward_from_null(self, index): + ts = self.ts() + tree1 = StatefulTree(ts) + tree1.seek_forward(index) + tree2 = StatefulTree(ts) + tree2.iter_forward(index) + tree1.assert_equal(tree2) + + @pytest.mark.parametrize("num_trees", [1, 5, 10, 50, 100]) + def test_seek_forward_from_mid(self, num_trees): + ts = self.ts() + start_index = ts.num_trees // 2 + dest_index = min(start_index + num_trees, ts.num_trees - 1) + tree1 = StatefulTree(ts) + tree1.iter_forward(start_index) + tree1.seek_forward(dest_index) + tree2 = StatefulTree(ts) + tree2.iter_forward(dest_index) + tree1.assert_equal(tree2) + + def test_forward_full(self): + check_iters_forward(self.ts()) + + def test_back_full(self): + check_iters_back(self.ts()) + + +class TestSuiteExamples: + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_forward_full(self, ts): + check_iters_forward(ts) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_back_full(self, ts): + check_iters_back(ts) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_seek_forward_from_null(self, ts): + index = ts.num_trees // 2 + tree1 = StatefulTree(ts) + tree1.seek_forward(index) + tree2 = StatefulTree(ts) + tree2.iter_forward(index) + tree1.assert_equal(tree2) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_seek_forward_from_first(self, ts): + index = ts.num_trees - 1 + tree1 = StatefulTree(ts) + tree1.next() + tree1.seek_forward(index) + tree2 = StatefulTree(ts) + tree2.iter_forward(index) + tree1.assert_equal(tree2) diff --git a/python/tests/tsutil.py b/python/tests/tsutil.py index 34334e9be0..b86a159274 100644 --- a/python/tests/tsutil.py +++ b/python/tests/tsutil.py @@ -24,11 +24,13 @@ A collection of utilities to edit and construct tree sequences. """ import collections +import dataclasses import functools import json import random import string import struct +import typing import msprime import numpy as np @@ -1713,6 +1715,196 @@ def iterate(self): left = right +FORWARD = 1 +REVERSE = -1 + + +@dataclasses.dataclass +class Interval: + left: float + right: float + + def __iter__(self): + yield self.left + yield self.right + + +@dataclasses.dataclass +class EdgeRange: + start: int + stop: int + order: typing.List + + +class TreePosition: + def __init__(self, ts): + self.ts = ts + self.index = -1 + self.direction = 0 + self.interval = Interval(0, 0) + self.in_range = EdgeRange(0, 0, None) + self.out_range = EdgeRange(0, 0, None) + + def __str__(self): + s = f"index: {self.index}\ninterval: {self.interval}\n" + s += f"direction: {self.direction}\n" + s += f"in_range: {self.in_range}\n" + s += f"out_range: {self.out_range}\n" + return s + + def set_null(self): + self.index = -1 + self.interval.left = 0 + self.interval.right = 0 + + def next(self): # NOQA: A003 + M = self.ts.num_edges + breakpoints = self.ts.breakpoints(as_array=True) + left_coords = self.ts.edges_left + left_order = self.ts.indexes_edge_insertion_order + right_coords = self.ts.edges_right + right_order = self.ts.indexes_edge_removal_order + + if self.index == -1: + self.interval.right = 0 + self.out_range.stop = 0 + self.in_range.stop = 0 + self.direction = FORWARD + + if self.direction == FORWARD: + left_current_index = self.in_range.stop + right_current_index = self.out_range.stop + else: + left_current_index = self.out_range.stop + 1 + right_current_index = self.in_range.stop + 1 + + left = self.interval.right + + j = right_current_index + self.out_range.start = j + while j < M and right_coords[right_order[j]] == left: + j += 1 + self.out_range.stop = j + self.out_range.order = right_order + + j = left_current_index + self.in_range.start = j + while j < M and left_coords[left_order[j]] == left: + j += 1 + self.in_range.stop = j + self.in_range.order = left_order + + self.direction = FORWARD + self.index += 1 + if self.index == self.ts.num_trees: + self.set_null() + else: + self.interval.left = left + self.interval.right = breakpoints[self.index + 1] + return self.index != -1 + + def prev(self): + M = self.ts.num_edges + breakpoints = self.ts.breakpoints(as_array=True) + right_coords = self.ts.edges_right + right_order = self.ts.indexes_edge_removal_order + left_coords = self.ts.edges_left + left_order = self.ts.indexes_edge_insertion_order + + if self.index == -1: + self.index = self.ts.num_trees + self.interval.left = self.ts.sequence_length + self.in_range.stop = M - 1 + self.out_range.stop = M - 1 + self.direction = REVERSE + + if self.direction == REVERSE: + left_current_index = self.out_range.stop + right_current_index = self.in_range.stop + else: + left_current_index = self.in_range.stop - 1 + right_current_index = self.out_range.stop - 1 + + right = self.interval.left + + j = left_current_index + self.out_range.start = j + while j >= 0 and left_coords[left_order[j]] == right: + j -= 1 + self.out_range.stop = j + self.out_range.order = left_order + + j = right_current_index + self.in_range.start = j + while j >= 0 and right_coords[right_order[j]] == right: + j -= 1 + self.in_range.stop = j + self.in_range.order = right_order + + self.direction = REVERSE + self.index -= 1 + if self.index == -1: + self.set_null() + else: + self.interval.left = breakpoints[self.index] + self.interval.right = right + return self.index != -1 + + def seek_forward(self, index): + # NOTE this is still in development and not fully tested. + assert index >= self.index and index < self.ts.num_trees + M = self.ts.num_edges + breakpoints = self.ts.breakpoints(as_array=True) + left_coords = self.ts.edges_left + left_order = self.ts.indexes_edge_insertion_order + right_coords = self.ts.edges_right + right_order = self.ts.indexes_edge_removal_order + + if self.index == -1: + self.interval.right = 0 + self.out_range.stop = 0 + self.in_range.stop = 0 + self.direction = FORWARD + + if self.direction == FORWARD: + left_current_index = self.in_range.stop + right_current_index = self.out_range.stop + else: + left_current_index = self.out_range.stop + 1 + right_current_index = self.in_range.stop + 1 + + self.direction = FORWARD + left = breakpoints[index] + + # The range of edges we need consider for removal starts + # at the current right index and ends at the first edge + # where the right coordinate is equal to the new tree's + # left coordinate. + j = right_current_index + self.out_range.start = j + # TODO This could be done with binary search + while j < M and right_coords[right_order[j]] <= left: + j += 1 + self.out_range.stop = j + + # The range of edges we need to consider for the new tree + # must have right coordinate > left + j = left_current_index + while j < M and right_coords[left_order[j]] <= left: + j += 1 + self.in_range.start = j + # TODO this could be done with a binary search + while j < M and left_coords[left_order[j]] <= left: + j += 1 + self.in_range.stop = j + + self.interval.left = left + self.interval.right = breakpoints[index + 1] + self.out_range.order = right_order + self.in_range.order = left_order + self.index = index + + def mean_descendants(ts, reference_sets): """ Returns the mean number of nodes from the specified reference sets