diff --git a/c/CHANGELOG.rst b/c/CHANGELOG.rst index 11606a42f2..7302ba3341 100644 --- a/c/CHANGELOG.rst +++ b/c/CHANGELOG.rst @@ -1,3 +1,13 @@ +-------- +UPCOMING +-------- + +**Features** + +- Add the `tsk_treeseq_extend_edges` method that can compress a tree sequence + by extending edges into adjacent trees and thus creating unary nodes in those + trees (:user:`petrelharp`, :user:`hfr1tze`, :user:`avabamf`, :pr:`2651`). + -------------------- [1.1.2] - 2023-05-17 -------------------- @@ -24,7 +34,7 @@ (:user:`jeromekelleher`, :issue:`2662`, :pr:`2663`). - Guarantee that unfiltered tables are not written to unnecessarily - during simplify (:user:`jeromekelleher` :pr:`2619`). + during simplify (:user:`jeromekelleher`, :pr:`2619`). - Add `x_table_keep_rows` methods to provide efficient in-place table subsetting (:user:`jeromekelleher`, :pr:`2700`). diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c index 63b7292322..a9fee8a8ff 100644 --- a/c/tests/test_trees.c +++ b/c/tests/test_trees.c @@ -8212,6 +8212,165 @@ test_split_edges_errors(void) tsk_treeseq_free(&ts); } +static void +test_extend_edges_simple(void) +{ + int ret; + tsk_treeseq_t ts, ets; + const char *nodes = "1 0 -1 -1\n" + "1 0 -1 -1\n" + "0 2.0 -1 -1\n"; + const char *edges = "0 10 2 0\n" + "0 10 2 1\n"; + const char *sites = "0.0 0\n" + "1.0 0\n"; + const char *mutations = "0 0 1 -1 0.5\n" + "1 1 1 -1 0.5\n"; + + tsk_treeseq_from_text(&ts, 10, nodes, edges, NULL, sites, mutations, NULL, NULL, 0); + ret = tsk_treeseq_extend_edges(&ts, 10, 0, &ets); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_table_collection_equals(ts.tables, ets.tables, 0)); + tsk_treeseq_free(&ts); + + tsk_treeseq_free(&ets); +} + +static void +test_extend_edges_errors(void) +{ + int ret; + tsk_treeseq_t ts, ets; + const char *nodes = "1 0 -1 -1\n" + "1 0 -1 -1\n" + "0 2.0 -1 -1\n"; + const char *edges = "0 10 2 0\n" + "0 10 2 1\n"; + const char *sites = "0.0 0\n" + "1.0 0\n"; + const char *mutations = "0 0 1 -1 0.5\n" + "1 1 1 -1 0.5\n"; + const char *mutations_no_time = "0 0 1 -1\n" + "1 1 1 -1\n"; + // left, right, node source, dest, time + const char *migrations = "0 10 0 0 1 0.5\n" + "0 10 0 1 0 1.5\n"; + + tsk_treeseq_from_text(&ts, 10, nodes, edges, NULL, sites, mutations, NULL, NULL, 0); + ret = tsk_treeseq_extend_edges(&ts, -2, 0, &ets); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_EXTEND_EDGES_BAD_MAXITER); + tsk_treeseq_free(&ts); + + tsk_treeseq_from_text( + &ts, 10, nodes, edges, migrations, sites, mutations, NULL, NULL, 0); + ret = tsk_treeseq_extend_edges(&ts, 10, 0, &ets); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MIGRATIONS_NOT_SUPPORTED); + tsk_treeseq_free(&ts); + + tsk_treeseq_from_text( + &ts, 10, nodes, edges, NULL, sites, mutations_no_time, NULL, NULL, 0); + ret = tsk_treeseq_extend_edges(&ts, 10, 0, &ets); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_DISALLOWED_UNKNOWN_MUTATION_TIME); + tsk_treeseq_free(&ts); + + tsk_treeseq_free(&ets); +} + +static void +assert_equal_except_edges_and_mutation_nodes( + const tsk_treeseq_t *ts1, const tsk_treeseq_t *ts2) +{ + tsk_table_collection_t t1, t2; + int ret; + + ret = tsk_table_collection_copy(ts1->tables, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_table_collection_copy(ts2->tables, &t2, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tsk_memset(t1.mutations.node, 0, t1.mutations.num_rows * sizeof(*t1.mutations.node)); + tsk_memset(t2.mutations.node, 0, t2.mutations.num_rows * sizeof(*t2.mutations.node)); + + tsk_edge_table_clear(&t1.edges); + tsk_edge_table_clear(&t2.edges); + + CU_ASSERT_TRUE(tsk_table_collection_equals(&t1, &t2, 0)); + + tsk_table_collection_free(&t1); + tsk_table_collection_free(&t2); +} + +static void +test_extend_edges(void) +{ + int ret, max_iter; + tsk_treeseq_t ts, ets; + /* 7 and 8 should be extended to the whole sequence + + 6 6 6 6 + +-+-+ +-+-+ +-+-+ +-+-+ + | | 7 | | 8 | | + | | ++-+ | | +-++ | | + 4 5 4 | | 4 | 5 4 5 + +++ +++ +++ | | | | +++ +++ +++ + 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 + */ + + const char *nodes = "1 0 -1 -1\n" + "1 0 -1 -1\n" + "1 0 -1 -1\n" + "1 0 -1 -1\n" + "0 1.0 -1 -1\n" + "0 1.0 -1 -1\n" + "0 3.0 -1 -1\n" + "0 2.0 -1 -1\n" + "0 2.0 -1 -1\n"; + // l, r, p, c + const char *edges = "0 10 4 0\n" + "0 5 4 1\n" + "7 10 4 1\n" + "0 2 5 2\n" + "5 10 5 2\n" + "0 2 5 3\n" + "5 10 5 3\n" + "2 5 7 2\n" + "2 5 7 4\n" + "5 7 8 1\n" + "5 7 8 5\n" + "2 5 6 3\n" + "0 2 6 4\n" + "5 10 6 4\n" + "0 2 6 5\n" + "7 10 6 5\n" + "2 5 6 7\n" + "5 7 6 8\n"; + const char *sites = "0.0 0\n" + "9.0 0\n"; + const char *mutations = "0 4 1 -1 2.5\n" + "0 4 2 0 1.5\n" + "1 5 1 -1 2.5\n" + "1 5 2 2 1.5\n"; + + tsk_treeseq_from_text(&ts, 10, nodes, edges, NULL, sites, mutations, NULL, NULL, 0); + + for (max_iter = 1; max_iter < 10; max_iter++) { + ret = tsk_treeseq_extend_edges(&ts, max_iter, 0, &ets); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_equal_except_edges_and_mutation_nodes(&ts, &ets); + CU_ASSERT_TRUE(ets.tables->edges.num_rows >= 12); + tsk_treeseq_free(&ets); + } + + ret = tsk_treeseq_extend_edges(&ts, 10, 0, &ets); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(ets.tables->nodes.num_rows, 9); + CU_ASSERT_EQUAL_FATAL(ets.tables->edges.num_rows, 12); + tsk_treeseq_free(&ets); + + tsk_treeseq_free(&ts); +} + static void test_init_take_ownership_no_edge_metadata(void) { @@ -8431,6 +8590,9 @@ main(int argc, char **argv) { "test_split_edges_no_populations", test_split_edges_no_populations }, { "test_split_edges_populations", test_split_edges_populations }, { "test_split_edges_errors", test_split_edges_errors }, + { "test_extend_edges_simple", test_extend_edges_simple }, + { "test_extend_edges_errors", test_extend_edges_errors }, + { "test_extend_edges", test_extend_edges }, { "test_init_take_ownership_no_edge_metadata", test_init_take_ownership_no_edge_metadata }, { NULL, NULL }, diff --git a/c/tskit/core.c b/c/tskit/core.c index 44be9740fc..08c2269b59 100644 --- a/c/tskit/core.c +++ b/c/tskit/core.c @@ -330,6 +330,11 @@ tsk_strerror_internal(int err) "values for any single site. " "(TSK_ERR_MUTATION_TIME_HAS_BOTH_KNOWN_AND_UNKNOWN)"; break; + case TSK_ERR_DISALLOWED_UNKNOWN_MUTATION_TIME: + ret = "Some mutation times are marked 'unknown' for a method that requires " + "no unknown times. (Use compute_mutation_times to add times?) " + "(TSK_ERR_DISALLOWED_UNKNOWN_MUTATION_TIME)"; + break; /* Migration errors */ case TSK_ERR_UNSORTED_MIGRATIONS: @@ -615,6 +620,11 @@ tsk_strerror_internal(int err) "if an individual has nodes from more than one time. " "(TSK_ERR_INDIVIDUAL_TIME_MISMATCH)"; break; + + case TSK_ERR_EXTEND_EDGES_BAD_MAXITER: + ret = "Maximum number of iterations must be positive. " + "(TSK_ERR_EXTEND_EDGES_BAD_MAXITER)"; + break; } return ret; } diff --git a/c/tskit/core.h b/c/tskit/core.h index a10f4376a5..52caa836d6 100644 --- a/c/tskit/core.h +++ b/c/tskit/core.h @@ -500,6 +500,11 @@ the edge on which it occurs, and wasn't TSK_UNKNOWN_TIME. A single site had a mixture of known mutation times and TSK_UNKNOWN_TIME */ #define TSK_ERR_MUTATION_TIME_HAS_BOTH_KNOWN_AND_UNKNOWN -509 +/** +Some mutations have TSK_UNKNOWN_TIME in an algorithm where that's +disallowed (use compute_mutation_times?). +*/ +#define TSK_ERR_DISALLOWED_UNKNOWN_MUTATION_TIME -510 /** @} */ /** @@ -865,6 +870,16 @@ An individual had nodes from more than one time */ #define TSK_ERR_INDIVIDUAL_TIME_MISMATCH -1704 /** @} */ + +/** +@defgroup EXTEND_EDGES_ERROR_GROUP Extend edges errors. +@{ +*/ +/** +Maximum iteration number (max_iter) must be positive. +*/ +#define TSK_ERR_EXTEND_EDGES_BAD_MAXITER -1800 +/** @} */ // clang-format on /* This bit is 0 for any errors originating from kastore */ diff --git a/c/tskit/trees.c b/c/tskit/trees.c index f51a4a0333..56991b5f40 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -7644,3 +7644,400 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples, tsk_safe_free(sample_index_map); return ret; } + +/* ======================================================== * + * Extend edges + * ======================================================== */ + +typedef struct _edge_list_t { + tsk_id_t edge; + // the `extended` flags records whether we have decided to extend + // this entry to the current tree? + bool extended; + struct _edge_list_t *next; +} edge_list_t; + +static int +extend_edges_append_entry( + edge_list_t **head, edge_list_t **tail, tsk_blkalloc_t *heap, tsk_id_t edge) +{ + int ret = 0; + edge_list_t *x = NULL; + + x = tsk_blkalloc_get(heap, sizeof(*x)); + if (x == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + + x->edge = edge; + x->extended = false; + x->next = NULL; + + if (*tail == NULL) { + *head = x; + } else { + (*tail)->next = x; + } + *tail = x; +out: + return ret; +} + +static void +remove_unextended(edge_list_t **head, edge_list_t **tail) +{ + edge_list_t *px, *x; + + px = *head; + while (px != NULL && !px->extended) { + px = px->next; + } + *head = px; + if (px != NULL) { + px->extended = false; + x = px->next; + while (x != NULL) { + if (x->extended) { + x->extended = false; + px->next = x; + px = x; + } + x = x->next; + } + px->next = NULL; + } + *tail = px; +} + +static int +tsk_treeseq_extend_edges_iter( + const tsk_treeseq_t *self, int direction, tsk_edge_table_t *edges) +{ + // Note: this modifies the edge table, but it does this by (a) removing + // some edges, and (b) extending left/right endpoints of others, + // while keeping order the same, and so this maintains sortedness + // (so, there is no need to sort afterwards). + int ret = 0; + tsk_id_t tj; + tsk_id_t e, e_out, e_in; + tsk_id_t c, p, p_in; + tsk_blkalloc_t edge_list_heap; + double *near_side, *far_side; + edge_list_t *edges_in_head, *edges_in_tail; + edge_list_t *edges_out_head, *edges_out_tail; + edge_list_t *ex_out, *ex_in; + double there, left, right; + bool forwards = (direction == TSK_DIR_FORWARD); + tsk_tree_position_t tree_pos; + bool valid; + const tsk_table_collection_t *tables = self->tables; + const tsk_size_t num_nodes = tables->nodes.num_rows; + const tsk_size_t num_edges = tables->edges.num_rows; + tsk_id_t *degree = tsk_calloc(num_nodes, sizeof(*degree)); + tsk_id_t *out_parent = tsk_malloc(num_nodes * sizeof(*out_parent)); + tsk_bool_t *keep = tsk_calloc(num_edges, sizeof(*keep)); + bool *not_sample = tsk_malloc(num_nodes * sizeof(*not_sample)); + + tsk_memset(&edge_list_heap, 0, sizeof(edge_list_heap)); + tsk_memset(&tree_pos, 0, sizeof(tree_pos)); + + if (keep == NULL || out_parent == NULL || degree == NULL || not_sample == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + tsk_memset(out_parent, 0xff, num_nodes * sizeof(*out_parent)); + + ret = tsk_blkalloc_init(&edge_list_heap, 8192); + if (ret != 0) { + goto out; + } + ret = tsk_tree_position_init(&tree_pos, self, 0); + if (ret != 0) { + goto out; + } + ret = tsk_edge_table_copy(&tables->edges, edges, TSK_NO_INIT); + if (ret != 0) { + goto out; + } + + for (tj = 0; tj < (tsk_id_t) tables->nodes.num_rows; tj++) { + not_sample[tj] = ((tables->nodes.flags[tj] & TSK_NODE_IS_SAMPLE) == 0); + } + + if (forwards) { + near_side = edges->left; + far_side = edges->right; + } else { + near_side = edges->right; + far_side = edges->left; + } + edges_in_head = NULL; + edges_in_tail = NULL; + edges_out_head = NULL; + edges_out_tail = NULL; + e_out = 0; // only to avoid an 'maybe uninitialized' compile warning + + if (forwards) { + valid = tsk_tree_position_next(&tree_pos); + } else { + valid = tsk_tree_position_prev(&tree_pos); + } + + while (valid) { + left = tree_pos.interval.left; + right = tree_pos.interval.right; + there = forwards ? right : left; + + // remove entries that aren't being extended/postponed + // and update out_parent + for (ex_out = edges_out_head; ex_out != NULL; ex_out = ex_out->next) { + e = ex_out->edge; + out_parent[edges->child[e]] = TSK_NULL; + } + remove_unextended(&edges_in_head, &edges_in_tail); + remove_unextended(&edges_out_head, &edges_out_tail); + for (ex_out = edges_out_head; ex_out != NULL; ex_out = ex_out->next) { + e = ex_out->edge; + out_parent[edges->child[e]] = edges->parent[e]; + } + + for (tj = tree_pos.out.start; tj != tree_pos.out.stop; tj += direction) { + e = tree_pos.out.order[tj]; + if (out_parent[edges->child[e]] == TSK_NULL) { + // add edge to pending_out + ret = extend_edges_append_entry( + &edges_out_head, &edges_out_tail, &edge_list_heap, e); + if (ret != 0) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + out_parent[edges->child[e]] = edges->parent[e]; + } + } + for (tj = tree_pos.in.start; tj != tree_pos.in.stop; tj += direction) { + e = tree_pos.in.order[tj]; + // add edge to pending_in + ret = extend_edges_append_entry( + &edges_in_head, &edges_in_tail, &edge_list_heap, e); + if (ret != 0) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + } + for (ex_out = edges_out_head; ex_out != NULL; ex_out = ex_out->next) { + e_out = ex_out->edge; + degree[edges->parent[e_out]] -= 1; + degree[edges->child[e_out]] -= 1; + tsk_bug_assert(out_parent[edges->child[e_out]] == edges->parent[e_out]); + } + for (ex_in = edges_in_head; ex_in != NULL; ex_in = ex_in->next) { + e_in = ex_in->edge; + degree[edges->parent[e_in]] += 1; + degree[edges->child[e_in]] += 1; + } + + for (ex_in = edges_in_head; ex_in != NULL; ex_in = ex_in->next) { + e_in = ex_in->edge; + // check whether the parent-child relationship exists in the + // sub-forest of edges to be removed: + // out_parent[p] != -1 only when it is the bottom of an edge to be + // removed, and degree[p] == 0 only if it is not in the new tree + c = edges->child[e_in]; + p = out_parent[c]; + p_in = edges->parent[e_in]; + while ((p != TSK_NULL) && (degree[p] == 0) && (p != p_in) && not_sample[p]) { + p = out_parent[p]; + } + if (p == p_in) { + // we can extend! + // But, we might have passed the interval that a + // postponed edge in covers, in which case + // we should skip postponing the edge in + if (far_side[e_in] != there) { + ex_in->extended = true; + } + near_side[e_in] = there; + while (c != p) { + for (ex_out = edges_out_head; ex_out != NULL; + ex_out = ex_out->next) { + e_out = ex_out->edge; + if (edges->child[e_out] == c) { + break; + } + } + tsk_bug_assert(edges->child[e_out] == c); + ex_out->extended = true; + far_side[e_out] = there; + // amend degree: the intermediate + // nodes have 2 edges instead of 0 + tsk_bug_assert(degree[c] == 0 || c == edges->child[e_in]); + if (degree[c] == 0) { + degree[c] = 2; + } + c = out_parent[c]; + } + } + } + if (forwards) { + valid = tsk_tree_position_next(&tree_pos); + } else { + valid = tsk_tree_position_prev(&tree_pos); + } + } + + for (e = 0; e < (tsk_id_t) num_edges; e++) { + keep[e] = edges->left[e] < edges->right[e]; + } + ret = tsk_edge_table_keep_rows(edges, keep, 0, NULL); +out: + tsk_blkalloc_free(&edge_list_heap); + tsk_tree_position_free(&tree_pos); + tsk_safe_free(degree); + tsk_safe_free(out_parent); + tsk_safe_free(keep); + tsk_safe_free(not_sample); + return ret; +} + +static int +tsk_treeseq_slide_mutation_nodes_up( + const tsk_treeseq_t *self, tsk_mutation_table_t *mutations) +{ + int ret = 0; + double t; + tsk_id_t c, p, next_mut; + const tsk_table_collection_t *tables = self->tables; + const tsk_size_t num_nodes = tables->nodes.num_rows; + double *sites_position = tables->sites.position; + double *nodes_time = tables->nodes.time; + tsk_tree_t tree; + + ret = tsk_tree_init(&tree, self, TSK_NO_SAMPLE_COUNTS); + if (ret != 0) { + goto out; + } + + next_mut = 0; + for (ret = tsk_tree_first(&tree); ret == TSK_TREE_OK; ret = tsk_tree_next(&tree)) { + while (next_mut < (tsk_id_t) mutations->num_rows + && sites_position[mutations->site[next_mut]] < tree.interval.right) { + t = mutations->time[next_mut]; + if (tsk_is_unknown_time(t)) { + ret = TSK_ERR_DISALLOWED_UNKNOWN_MUTATION_TIME; + goto out; + } + c = mutations->node[next_mut]; + tsk_bug_assert(c < (tsk_id_t) num_nodes); + p = tree.parent[c]; + while (p != TSK_NULL && nodes_time[p] <= t) { + c = p; + p = tree.parent[c]; + } + tsk_bug_assert(nodes_time[c] <= t); + mutations->node[next_mut] = c; + next_mut++; + } + } + if (ret != 0) { + goto out; + } + +out: + tsk_tree_free(&tree); + + return ret; +} + +int TSK_WARN_UNUSED +tsk_treeseq_extend_edges(const tsk_treeseq_t *self, int max_iter, + tsk_flags_t TSK_UNUSED(options), tsk_treeseq_t *output) +{ + int ret = 0; + tsk_table_collection_t tables; + tsk_treeseq_t ts; + int iter, j; + tsk_size_t last_num_edges; + const int direction[] = { TSK_DIR_FORWARD, TSK_DIR_REVERSE }; + + tsk_memset(&tables, 0, sizeof(tables)); + tsk_memset(&ts, 0, sizeof(ts)); + tsk_memset(output, 0, sizeof(*output)); + + if (max_iter <= 0) { + ret = TSK_ERR_EXTEND_EDGES_BAD_MAXITER; + goto out; + } + if (tsk_treeseq_get_num_migrations(self) != 0) { + ret = TSK_ERR_MIGRATIONS_NOT_SUPPORTED; + goto out; + } + + /* Note: there is a fair bit of copying of table data in this implementation + * currently, as we create a new tree sequence for each iteration, which + * takes a full copy of the input tables. We could streamline this by + * adding a flag to treeseq_init which says "steal a reference to these + * tables and *don't* free them at the end". Then, we would only need + * one copy of the full tables, and could pass in a standalone edge + * table to use for in-place updating. + */ + ret = tsk_table_collection_copy(self->tables, &tables, 0); + if (ret != 0) { + goto out; + } + ret = tsk_mutation_table_clear(&tables.mutations); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_init(&ts, &tables, 0); + if (ret != 0) { + goto out; + } + + last_num_edges = tsk_treeseq_get_num_edges(&ts); + for (iter = 0; iter < max_iter; iter++) { + for (j = 0; j < 2; j++) { + ret = tsk_treeseq_extend_edges_iter(&ts, direction[j], &tables.edges); + if (ret != 0) { + goto out; + } + /* We're done with the current ts now */ + tsk_treeseq_free(&ts); + ret = tsk_treeseq_init(&ts, &tables, TSK_TS_INIT_BUILD_INDEXES); + if (ret != 0) { + goto out; + } + } + if (last_num_edges == tsk_treeseq_get_num_edges(&ts)) { + break; + } + last_num_edges = tsk_treeseq_get_num_edges(&ts); + } + + /* Remap mutation nodes */ + ret = tsk_mutation_table_copy( + &self->tables->mutations, &tables.mutations, TSK_NO_INIT); + if (ret != 0) { + goto out; + } + /* Note: to allow migrations we'd also have to do this same operation + * on the migration nodes; however it's a can of worms because the interval + * covering the migration might no longer make sense. */ + ret = tsk_treeseq_slide_mutation_nodes_up(&ts, &tables.mutations); + if (ret != 0) { + goto out; + } + tsk_treeseq_free(&ts); + ret = tsk_treeseq_init(&ts, &tables, TSK_TS_INIT_BUILD_INDEXES); + if (ret != 0) { + goto out; + } + + /* Hand ownership of the tree sequence to the calling code */ + tsk_memcpy(output, &ts, sizeof(ts)); + tsk_memset(&ts, 0, sizeof(*output)); +out: + tsk_treeseq_free(&ts); + tsk_table_collection_free(&tables); + return ret; +} diff --git a/c/tskit/trees.h b/c/tskit/trees.h index a503b3e39b..7bc1e60092 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -891,6 +891,45 @@ int tsk_treeseq_simplify(const tsk_treeseq_t *self, const tsk_id_t *samples, tsk_size_t num_samples, tsk_flags_t options, tsk_treeseq_t *output, tsk_id_t *node_map); +/** +@brief Extends edges + +Returns a modified tree sequence in which the span covered by ancestral nodes +is "extended" to regions of the genome according to the following rule: +If an ancestral segment corresponding to node `n` has parent `p` and +child `c` on some portion of the genome, and on an adjacent segment of +genome `p` is the immediate parent of `c`, then `n` is inserted into the +edge from `p` to `c`. This involves extending the span of the edges +from `p` to `n` and `n` to `c` and reducing the span of the edge from +`p` to `c`. Since the latter edge may be removed entirely, this process +reduces (or at least does not increase) the number of edges in the tree +sequence. The `node` of certain mutations may also be remapped; to do this +unambiguously we need to know mutation times. If mutations times are unknown, +use `tsk_table_collection_compute_mutation_times` first. + +The method will not affect any tables except the edge table, or the node +column in the mutation table. + +The method works by iterating over the genome to look for edges that can +be extended in this way; the maximum number of such iterations is +controlled by ``max_iter``. + +Since this may change which nodes are above + +@rst + +**Options**: None currently defined. +@endrst + +@param self A pointer to a tsk_treeseq_t object. +@param max_iter The maximum number of iterations over the tree sequence. +@param options Bitwise option flags. (UNUSED) +@param output A pointer to an uninitialised tsk_treeseq_t object. +@return Return 0 on success or a negative value on failure. +*/ +int tsk_treeseq_extend_edges( + const tsk_treeseq_t *self, int max_iter, tsk_flags_t options, tsk_treeseq_t *output); + /** @} */ int tsk_treeseq_split_edges(const tsk_treeseq_t *self, double time, tsk_flags_t flags, diff --git a/docs/_config.yml b/docs/_config.yml index 474657635f..f65d071bbf 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -45,6 +45,7 @@ sphinx: html_theme: sphinx_book_theme html_theme_options: pygment_dark_style: monokai + navigation_with_keys: false pygments_style: monokai myst_enable_extensions: - colon_fence diff --git a/docs/c-api.rst b/docs/c-api.rst index bd8233ed6e..460ce74b58 100644 --- a/docs/c-api.rst +++ b/docs/c-api.rst @@ -796,6 +796,13 @@ Individual errors .. doxygengroup:: INDIVIDUAL_ERROR_GROUP :content-only: +------------------- +Extend edges errors +------------------- + +.. doxygengroup:: EXTEND_EDGES_ERROR_GROUP + :content-only: + .. _sec_c_api_examples: diff --git a/docs/python-api.md b/docs/python-api.md index a8236daadf..20a2d541b4 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -268,6 +268,7 @@ which perform the same actions but modify the {class}`TableCollection` in place. TreeSequence.trim TreeSequence.split_edges TreeSequence.decapitate + TreeSequence.extend_edges ``` (sec_python_api_tree_sequences_ibd)= diff --git a/python/.gitignore b/python/.gitignore index 1d3e405ad1..acdfa61981 100644 --- a/python/.gitignore +++ b/python/.gitignore @@ -3,3 +3,4 @@ *.egg-info build .*.swp +*/.ipynb_checkpoints diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 26d130b7e9..cb9dee04d2 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -2,6 +2,11 @@ [0.5.7] - 2023-XX-XX -------------------- +**Features** + +- Add ``TreeSequence.extend_edges`` method that extends ancestral haplotypes + using recombination information, leading to unary nodes in many trees and + fewer edges. (:user:`petrelharp`, :user:`hfr1tz3`, :user:`avabamf`, :pr:`2651`) -------------------- diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 431cf9afd2..ea42f98b5b 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -8938,6 +8938,46 @@ TreeSequence_mean_descendants(TreeSequence *self, PyObject *args, PyObject *kwds return ret; } +static PyObject * +TreeSequence_extend_edges(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + int err; + PyObject *ret = NULL; + int max_iter; + tsk_flags_t options = 0; + static char *kwlist[] = { "max_iter", NULL }; + TreeSequence *output = NULL; + + if (TreeSequence_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords(args, kwds, "i", kwlist, &max_iter)) { + goto out; + } + + output = (TreeSequence *) _PyObject_New((PyTypeObject *) &TreeSequenceType); + if (output == NULL) { + goto out; + } + output->tree_sequence = PyMem_Malloc(sizeof(*output->tree_sequence)); + if (output->tree_sequence == NULL) { + PyErr_NoMemory(); + goto out; + } + + err = tsk_treeseq_extend_edges( + self->tree_sequence, max_iter, options, output->tree_sequence); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = (PyObject *) output; + output = NULL; +out: + Py_XDECREF(output); + return ret; +} + /* Error value returned from summary_func callback if an error occured. * This is chosen so that it is not a valid tskit error code and so can * never be mistaken for a different error */ @@ -10531,6 +10571,10 @@ static PyMethodDef TreeSequence_methods[] = { .ml_meth = (PyCFunction) TreeSequence_split_edges, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Returns a copy of this tree sequence edges split at time t" }, + { .ml_name = "extend_edges", + .ml_meth = (PyCFunction) TreeSequence_extend_edges, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Extends edges, creating unary nodes." }, { .ml_name = "has_reference_sequence", .ml_meth = (PyCFunction) TreeSequence_has_reference_sequence, .ml_flags = METH_NOARGS, diff --git a/python/tests/test_extend_edges.py b/python/tests/test_extend_edges.py new file mode 100644 index 0000000000..7ee8c8f471 --- /dev/null +++ b/python/tests/test_extend_edges.py @@ -0,0 +1,654 @@ +import msprime +import numpy as np +import pytest + +import _tskit +import tests.test_wright_fisher as wf +import tskit +from tests import tsutil +from tests.test_highlevel import get_example_tree_sequences + +# ↑ See https://github.com/tskit-dev/tskit/issues/1804 for when +# we can remove this. + + +def extend_edges(ts, max_iter=10): + tables = ts.dump_tables() + mutations = tables.mutations.copy() + tables.mutations.clear() + + last_num_edges = ts.num_edges + for _ in range(max_iter): + for forwards in [True, False]: + edges = _extend(ts, forwards=forwards) + tables.edges.replace_with(edges) + tables.build_index() + ts = tables.tree_sequence() + if ts.num_edges == last_num_edges: + break + else: + last_num_edges = ts.num_edges + + tables = ts.dump_tables() + mutations = _slide_mutation_nodes_up(ts, mutations) + tables.mutations.replace_with(mutations) + ts = tables.tree_sequence() + + return ts + + +def _slide_mutation_nodes_up(ts, mutations): + # adjusts mutations' nodes to place each mutation on the correct edge given + # their time; requires mutation times be nonmissing and the mutation times + # be >= their nodes' times. + + assert np.all(~tskit.is_unknown_time(mutations.time)), "times must be known" + new_nodes = mutations.node.copy() + + mut = 0 + for tree in ts.trees(): + _, right = tree.interval + while ( + mut < mutations.num_rows and ts.sites_position[mutations.site[mut]] < right + ): + t = mutations.time[mut] + c = mutations.node[mut] + p = tree.parent(c) + assert ts.nodes_time[c] <= t + while p != -1 and ts.nodes_time[p] <= t: + c = p + p = tree.parent(c) + assert ts.nodes_time[c] <= t + if p != -1: + assert t < ts.nodes_time[p] + new_nodes[mut] = c + mut += 1 + + # in C the node column can be edited in place + new_mutations = mutations.copy() + new_mutations.clear() + for mut, n in zip(mutations, new_nodes): + new_mutations.append(mut.replace(node=n)) + + return new_mutations + + +def _extend(ts, forwards=True): + # `degree` will record the degree of each node in the tree we'd get if + # we removed all `out` edges and added all `in` edges + degree = np.full(ts.num_nodes, 0, dtype="int") + # `out_parent` will record the sub-forest of edges-to-be-removed + out_parent = np.full(ts.num_nodes, -1, dtype="int") + keep = np.full(ts.num_edges, True, dtype=bool) + not_sample = [not n.is_sample() for n in ts.nodes()] + + edges = ts.tables.edges.copy() + + # "here" will be left if fowards else right; + # and "there" is the other + new_left = edges.left.copy() + new_right = edges.right.copy() + if forwards: + direction = 1 + # in C we can just modify these in place, but in + # python they are (silently) immutable + near_side = new_left + far_side = new_right + else: + direction = -1 + near_side = new_right + far_side = new_left + edges_out = [] + edges_in = [] + + tree_pos = tsutil.TreePosition(ts) + if forwards: + valid = tree_pos.next() + else: + valid = tree_pos.prev() + while valid: + left, right = tree_pos.interval + there = right if forwards else left + + # Clear out non-extended or postponed edges: + # Note: maintaining out_parent is a bit tricky, because + # if an edge from p->c has been extended, entirely replacing + # another edge from p'->c, then both edges may be in edges_out, + # and we only want to include the *first* one. + for e, _ in edges_out: + out_parent[edges.child[e]] = -1 + tmp = [] + for e, x in edges_out: + if x: + tmp.append([e, False]) + edges_out = tmp + tmp = [] + for e, x in edges_in: + if x: + tmp.append([e, False]) + edges_in = tmp + + for e, _ in edges_out: + out_parent[edges.child[e]] = edges.parent[e] + + for j in range(tree_pos.out_range.start, tree_pos.out_range.stop, direction): + e = tree_pos.out_range.order[j] + if out_parent[edges.child[e]] == -1: + edges_out.append([e, False]) + out_parent[edges.child[e]] = edges.parent[e] + + for j in range(tree_pos.in_range.start, tree_pos.in_range.stop, direction): + e = tree_pos.in_range.order[j] + edges_in.append([e, False]) + + for e, _ in edges_out: + degree[edges.parent[e]] -= 1 + degree[edges.child[e]] -= 1 + for e, _ in edges_in: + degree[edges.parent[e]] += 1 + degree[edges.child[e]] += 1 + + # validate out_parent array + for c, p in enumerate(out_parent): + foundit = False + for e, _ in edges_out: + if edges.child[e] == c: + assert edges.parent[e] == p + foundit = True + break + assert foundit == (p != -1) + + assert np.all(degree >= 0) + for ex_in in edges_in: + e_in = ex_in[0] + # check whether the parent-child relationship exists + # in the sub-forest of edges to be removed: + # out_parent[p] != -1 only when it is the bottom of + # an edge to be removed, + # and degree[p] == 0 only if it is not in the new tree + c = edges.child[e_in] + p = out_parent[c] + p_in = edges.parent[e_in] + while p != tskit.NULL and degree[p] == 0 and p != p_in and not_sample[p]: + p = out_parent[p] + if p == p_in: + # we might have passed the interval that a + # postponed edge in covers, in which case + # we should skip it + if far_side[e_in] != there: + ex_in[1] = True + near_side[e_in] = there + while c != p: + # just loop over the edges out until we find the right entry + for ex_out in edges_out: + e_out = ex_out[0] + if edges.child[e_out] == c: + break + assert edges.child[e_out] == c + ex_out[1] = True + far_side[e_out] = there + # amend degree: the intermediate + # nodes have 2 edges instead of 0 + assert degree[c] == 0 or c == edges.child[e_in] + if degree[c] == 0: + degree[c] = 2 + c = out_parent[c] + + # end of loop, next tree + if forwards: + valid = tree_pos.next() + else: + valid = tree_pos.prev() + + for j in range(edges.num_rows): + left = new_left[j] + right = new_right[j] + if left < right: + edges[j] = edges[j].replace(left=left, right=right) + else: + keep[j] = False + edges.keep_rows(keep) + return edges + + +class TestExtendEdges: + """ + Test the 'extend edges' method + """ + + def verify_extend_edges(self, ts, max_iter=10): + # This can still fail for various weird examples: + # for instance, if adjacent trees have + # a <- b <- c <- d and a <- d (where say b was + # inserted in an earlier pass), then b and c + # won't be extended + + ets = ts.extend_edges(max_iter=max_iter) + assert np.all(ts.genotype_matrix() == ets.genotype_matrix()) + assert ts.num_samples == ets.num_samples + assert ts.num_nodes == ets.num_nodes + assert ts.num_edges >= ets.num_edges + t = ts.simplify().tables + et = ets.simplify().tables + t.assert_equals(et, ignore_provenance=True) + old_edges = {} + for e in ts.edges(): + k = (e.parent, e.child) + if k not in old_edges: + old_edges[k] = [] + old_edges[k].append((e.left, e.right)) + + for e in ets.edges(): + # e should be in old_edges, + # but with modified limits: + # USUALLY overlapping limits, but + # not necessarily after more than one pass + k = (e.parent, e.child) + assert k in old_edges + if max_iter == 1: + overlaps = False + for left, right in old_edges[k]: + if (left <= e.right) and (right >= e.left): + overlaps = True + assert overlaps + + if max_iter > 1: + chains = [] + for _, tt, ett in ts.coiterate(ets): + this_chains = [] + for a in tt.nodes(): + assert a in ett.nodes() + b = tt.parent(a) + if b != tskit.NULL: + c = tt.parent(b) + if c != tskit.NULL: + this_chains.append((a, b, c)) + assert b in ett.nodes() + # the relationship a <- b should still be in the tree + p = a + while p != tskit.NULL and p != b: + p = ett.parent(p) + assert p == b + chains.append(this_chains) + + extended_ac = {} + not_extended_ac = {} + extended_ab = {} + not_extended_ab = {} + for k, (interval, tt, ett) in enumerate(ts.coiterate(ets)): + for j in (k - 1, k + 1): + if j < 0 or j >= len(chains): + continue + else: + this_chains = chains[j] + for a, b, c in this_chains: + if ( + a in tt.nodes() + and tt.parent(a) == c + and b not in tt.nodes() + ): + # the relationship a <- b <- c should still be in the tree, + # although maybe they aren't direct parent-offspring + # UNLESS we've got an ambiguous case, where on the opposite + # side of the interval a chain a <- b' <- c got extended + # into the region OR b got inserted into another chain + assert a in ett.nodes() + assert c in ett.nodes() + if b not in ett.nodes(): + if (a, c) not in not_extended_ac: + not_extended_ac[(a, c)] = [] + not_extended_ac[(a, c)].append(interval) + else: + if (a, c) not in extended_ac: + extended_ac[(a, c)] = [] + extended_ac[(a, c)].append(interval) + p = a + while p != tskit.NULL and p != b: + p = ett.parent(p) + if p != b: + if (a, b) not in not_extended_ab: + not_extended_ab[(a, b)] = [] + not_extended_ab[(a, b)].append(interval) + else: + if (a, b) not in extended_ab: + extended_ab[(a, b)] = [] + extended_ab[(a, b)].append(interval) + while p != tskit.NULL and p != c: + p = ett.parent(p) + assert p == c + for a, c in not_extended_ac: + # check that a <- ... <- c has been extended somewhere + # although not necessarily from an adjacent segment + assert (a, c) in extended_ac + for interval in not_extended_ac[(a, c)]: + ett = ets.at(interval.left) + assert ett.parent(a) != c + for k in not_extended_ab: + assert k in extended_ab + for interval in not_extended_ab[k]: + assert interval in extended_ab[k] + + # finally, compare C version to python version + py_ts = extend_edges(ts, max_iter=max_iter) + py_et = py_ts.dump_tables() + et = ets.dump_tables() + et.assert_equals(py_et) + + def test_runs(self): + ts = msprime.simulate(5, mutation_rate=1.0, random_seed=126) + self.verify_extend_edges(ts) + + def test_migrations_disallowed(self): + ts = msprime.simulate(5, mutation_rate=1.0, random_seed=126) + tables = ts.dump_tables() + tables.populations.add_row() + tables.populations.add_row() + tables.migrations.add_row(0, 1, 0, 0, 1, 0) + ts = tables.tree_sequence() + with pytest.raises( + _tskit.LibraryError, match="TSK_ERR_MIGRATIONS_NOT_SUPPORTED" + ): + _ = ts.extend_edges() + + def test_unknown_times(self): + ts = msprime.simulate(5, mutation_rate=1.0, random_seed=126) + tables = ts.dump_tables() + tables.mutations.clear() + for mut in ts.mutations(): + tables.mutations.append(mut.replace(time=tskit.UNKNOWN_TIME)) + ts = tables.tree_sequence() + with pytest.raises( + _tskit.LibraryError, match="TSK_ERR_DISALLOWED_UNKNOWN_MUTATION_TIME" + ): + _ = ts.extend_edges() + + def test_max_iter(self): + ts = msprime.simulate(5, random_seed=126) + with pytest.raises(_tskit.LibraryError, match="positive"): + ets = ts.extend_edges(max_iter=0) + with pytest.raises(_tskit.LibraryError, match="positive"): + ets = ts.extend_edges(max_iter=-1) + ets = ts.extend_edges(max_iter=1) + et = ets.extend_edges(max_iter=1).dump_tables() + eet = ets.extend_edges(max_iter=2).dump_tables() + eet.assert_equals(et) + + def get_simple_ex(self, samples=None): + # An example where you need to go forwards *and* backwards: + # 7 and 8 should be extended to the whole sequence + # + # 6 6 6 6 + # +-+-+ +-+-+ +-+-+ +-+-+ + # | | 7 | | 8 | | + # | | ++-+ | | +-++ | | + # 4 5 4 | | 4 | 5 4 5 + # +++ +++ +++ | | | | +++ +++ +++ + # 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 + # + # Result: + # + # 6 6 6 6 + # +-+-+ +-+-+ +-+-+ +-+-+ + # 7 8 7 8 7 8 7 8 + # | | ++-+ | | +-++ | | + # 4 5 4 | 5 4 | 5 4 5 + # +++ +++ +++ | | | | +++ +++ +++ + # 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 + + node_times = { + 0: 0, + 1: 0, + 2: 0, + 3: 0, + 4: 1.0, + 5: 1.0, + 6: 3.0, + 7: 2.0, + 8: 2.0, + } + # (p, c, l, r) + edges = [ + (4, 0, 0, 10), + (4, 1, 0, 5), + (4, 1, 7, 10), + (5, 2, 0, 2), + (5, 2, 5, 10), + (5, 3, 0, 2), + (5, 3, 5, 10), + (7, 2, 2, 5), + (7, 4, 2, 5), + (8, 1, 5, 7), + (8, 5, 5, 7), + (6, 3, 2, 5), + (6, 4, 0, 2), + (6, 4, 5, 10), + (6, 5, 0, 2), + (6, 5, 7, 10), + (6, 7, 2, 5), + (6, 8, 5, 7), + ] + # here is the 'right answer' (but note only with the default args) + extended_edges = [ + (4, 0, 0, 10), + (4, 1, 0, 5), + (4, 1, 7, 10), + (5, 2, 0, 2), + (5, 2, 5, 10), + (5, 3, 0, 10), + (7, 2, 2, 5), + (7, 4, 0, 10), + (8, 1, 5, 7), + (8, 5, 0, 10), + (6, 7, 0, 10), + (6, 8, 0, 10), + ] + tables = tskit.TableCollection(sequence_length=10) + if samples is None: + samples = [0, 1, 2, 3] + for n, t in node_times.items(): + flags = tskit.NODE_IS_SAMPLE if n in samples else 0 + tables.nodes.add_row(time=t, flags=flags) + for p, c, l, r in edges: + tables.edges.add_row(parent=p, child=c, left=l, right=r) + ts = tables.tree_sequence() + tables.edges.clear() + for p, c, l, r in extended_edges: + tables.edges.add_row(parent=p, child=c, left=l, right=r) + ets = tables.tree_sequence() + assert ts.num_edges == 18 + assert ets.num_edges == 12 + return ts, ets + + def test_simple_ex(self): + ts, right_ets = self.get_simple_ex() + ets = ts.extend_edges() + ets.tables.assert_equals(right_ets.tables) + self.verify_extend_edges(ts) + + def test_internal_samples(self): + # Now we should have the same but not extend 5 (where * is): + # + # 6 6 6 6 + # +-+-+ +-+-+ +-+-+ +-+-+ + # 7 * 7 * 7 8 7 8 + # | | ++-+ | | +-++ | | + # 4 5 4 | * 4 | 5 4 5 + # +++ +++ +++ | | | | +++ +++ +++ + # 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 + # + # (p, c, l, r) + edges = [ + (4, 0, 0, 10), + (4, 1, 0, 5), + (4, 1, 7, 10), + (5, 2, 0, 2), + (5, 2, 5, 10), + (5, 3, 0, 2), + (5, 3, 5, 10), + (7, 2, 2, 5), + (7, 4, 0, 10), + (8, 1, 5, 7), + (8, 5, 5, 10), + (6, 3, 2, 5), + (6, 5, 0, 2), + (6, 7, 0, 10), + (6, 8, 5, 10), + ] + ts, _ = self.get_simple_ex(samples=[0, 1, 2, 3, 5]) + tables = ts.dump_tables() + tables.edges.clear() + for p, c, l, r in edges: + tables.edges.add_row(parent=p, child=c, left=l, right=r) + ets = ts.extend_edges() + ets.tables.assert_equals(tables) + # validation doesn't work with internal, incomplete samples + # (and it would be a pain to make it work) + # self.verify_extend_edges(ts) + + def test_iterative_example(self): + # Here is the full tree; extend edges should be able to + # recover all unary nodes after simplification: + # + # 9 9 9 9 + # +-+-+ +--+--+ +---+---+ +-+-+--+ + # 8 | 8 | 8 | | 8 | | | + # | | +-+-+ | | | | | | | | + # 7 | | 7 | | 7 | | | | 7 + # +-+-+ | | +-++ | | +-++ | | | | | + # 6 | | | | 6 | | | 6 | | | | 6 + # +-++ | | | | | | | | | | | | | | + # 1 0 2 3 1 2 0 3 1 2 0 3 1 2 3 0 + # +++ +++ +++ +++ + # 4 5 4 5 4 5 4 5 + # + samples = [0, 1, 2, 3, 4, 5] + node_times = [1, 1, 1, 1, 0, 0, 2, 3, 4, 5] + # (p, c, l, r) + edges = [ + (0, 4, 0, 10), + (0, 5, 0, 10), + (6, 0, 0, 10), + (6, 1, 0, 3), + (7, 2, 0, 7), + (7, 6, 0, 10), + (8, 1, 3, 10), + (8, 7, 0, 5), + (9, 2, 7, 10), + (9, 3, 0, 10), + (9, 7, 5, 10), + (9, 8, 0, 10), + ] + tables = tskit.TableCollection(sequence_length=10) + for n, t in enumerate(node_times): + flags = tskit.NODE_IS_SAMPLE if n in samples else 0 + tables.nodes.add_row(time=t, flags=flags) + for p, c, l, r in edges: + tables.edges.add_row(parent=p, child=c, left=l, right=r) + ts = tables.tree_sequence() + sts = ts.simplify() + assert ts.num_edges == 12 + assert sts.num_edges == 16 + tables.assert_equals(sts.extend_edges().tables, ignore_provenance=True) + + def test_very_simple(self): + samples = [0] + node_times = [0, 1, 2, 3] + # (p, c, l, r) + edges = [ + (1, 0, 0, 1), + (2, 0, 1, 2), + (2, 1, 0, 1), + (3, 0, 2, 3), + (3, 2, 0, 2), + ] + correct_edges = [ + (1, 0, 0, 3), + (2, 1, 0, 3), + (3, 2, 0, 3), + ] + tables = tskit.TableCollection(sequence_length=3) + for n, t in enumerate(node_times): + flags = tskit.NODE_IS_SAMPLE if n in samples else 0 + tables.nodes.add_row(time=t, flags=flags) + for p, c, l, r in edges: + tables.edges.add_row(parent=p, child=c, left=l, right=r) + ts = tables.tree_sequence() + ets = ts.extend_edges() + for _, t, et in ts.coiterate(ets): + print("----") + print(t.draw(format="ascii")) + print(et.draw(format="ascii")) + etables = ets.tables + correct_tables = etables.copy() + etables.edges.clear() + for p, c, l, r in correct_edges: + etables.edges.add_row(parent=p, child=c, left=l, right=r) + etables.assert_equals(correct_tables, ignore_provenance=True) + + def test_wright_fisher(self): + tables = wf.wf_sim(N=5, ngens=20, num_loci=100, deep_history=False, seed=3) + tables.sort() + tables.simplify() + ts = msprime.sim_mutations(tables.tree_sequence(), rate=0.01, random_seed=888) + self.verify_extend_edges(ts, max_iter=1) + self.verify_extend_edges(ts) + + def test_wright_fisher_unsimplified(self): + tables = wf.wf_sim(N=6, ngens=22, num_loci=100, deep_history=False, seed=4) + tables.sort() + ts = msprime.sim_mutations(tables.tree_sequence(), rate=0.01, random_seed=888) + self.verify_extend_edges(ts, max_iter=1) + self.verify_extend_edges(ts) + + def test_wright_fisher_with_history(self): + tables = wf.wf_sim(N=8, ngens=15, num_loci=100, deep_history=True, seed=5) + tables.sort() + tables.simplify() + ts = msprime.sim_mutations(tables.tree_sequence(), rate=0.01, random_seed=888) + self.verify_extend_edges(ts, max_iter=1) + self.verify_extend_edges(ts) + + # This one fails sometimes but just because our verification can't handle + # figuring out what exactly should be the right answer in complex cases. + # + # def test_bigger_wright_fisher(self): + # tables = wf.wf_sim(N=50, ngens=15, deep_history=True, seed=6) + # tables.sort() + # tables.simplify() + # ts = tables.tree_sequence() + # self.verify_extend_edges(ts, max_iter=1) + # self.verify_extend_edges(ts, max_iter=200) + + +class TestExamples: + """ + Compare the ts method with local implementation. + """ + + def check(self, ts): + if np.any(tskit.is_unknown_time(ts.mutations_time)): + tables = ts.dump_tables() + tables.compute_mutation_times() + ts = tables.tree_sequence() + py_ts = extend_edges(ts) + lib_ts = ts.extend_edges() + lib_ts.tables.assert_equals(py_ts.tables) + assert np.all(ts.genotype_matrix() == lib_ts.genotype_matrix()) + sts = ts.simplify() + lib_sts = lib_ts.simplify() + lib_sts.tables.assert_equals(sts.tables, ignore_provenance=True) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_suite_examples_defaults(self, ts): + if ts.num_migrations == 0: + self.check(ts) + else: + with pytest.raises( + _tskit.LibraryError, match="TSK_ERR_MIGRATIONS_NOT_SUPPORTED" + ): + _ = ts.extend_edges() + + @pytest.mark.parametrize("n", [3, 4, 5]) + def test_all_trees_ts(self, n): + ts = tsutil.all_trees_ts(n) + self.check(ts) diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 8fc1e427a2..16676a240c 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -1496,6 +1496,22 @@ def test_time_units(self): ts.load_tables(tables) assert ts.get_time_units() == value + def test_extend_edges_bad_args(self): + ts1 = self.get_example_tree_sequence(10) + with pytest.raises(TypeError): + ts1.extend_edges() + with pytest.raises(TypeError, match="an integer"): + ts1.extend_edges("sdf") + with pytest.raises(_tskit.LibraryError, match="positive"): + ts1.extend_edges(0) + with pytest.raises(_tskit.LibraryError, match="positive"): + ts1.extend_edges(-1) + tsm = self.get_example_migration_tree_sequence() + with pytest.raises( + _tskit.LibraryError, match="TSK_ERR_MIGRATIONS_NOT_SUPPORTED" + ): + tsm.extend_edges(1) + def test_kc_distance_errors(self): ts1 = self.get_example_tree_sequence(10) with pytest.raises(TypeError): diff --git a/python/tskit/trees.py b/python/tskit/trees.py index df2fa2faf1..1db0333fbb 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -6912,6 +6912,56 @@ def decapitate(self, time, *, flags=None, population=None, metadata=None): tables.delete_older(time) return tables.tree_sequence() + def extend_edges(self, max_iter=10): + """ + Returns a new tree sequence in which the span covered by ancestral nodes + is "extended" to regions of the genome according to the following rule: + If an ancestral segment corresponding to node `n` has parent `p` and + child `c` on some portion of the genome, and on an adjacent segment of + genome `p` is the immediate parent of `c`, then `n` is inserted into the + edge from `p` to `c`. This involves extending the span of the edges + from `p` to `n` and `n` to `c` and reducing the span of the edge from + `p` to `c`. However, any edges whose child node is a sample will not + be modified. + + Since some edges may be removed entirely, this process reduces (or at + least does not increase) the number of edges in the tree sequence. + + *Note:* this is a somewhat experimental operation, and is probably not + what you are looking for. + + The method works by iterating over the genome to look for edges that can + be extended in this way; the maximum number of such iterations is + controlled by ``max_iter``. + + The rationale is that we know that `n` carries a portion of the segment + of ancestral genome inherited by `c` from `p`, and so likely carries + the *entire* inherited segment (since the implication otherwise would + be that distinct recombined segments were passed down separately from + `p` to `c`). + + If an edge that a mutation falls on is split by this operation, the + mutation's node may need to be moved. This is only unambiguous if the + mutation's time is known, so the method requires known mutation times. + See :meth:`.impute_unknown_mutations_time` if mutation times are + not known. + + The method will not affect the marginal trees (so, if the original tree + sequence was simplified, then following up with `simplify` will recover + the original tree sequence, possibly with edges in a different order). + It will also not affect the genotype matrix, or any of the tables other + than the edge table or the node column in the mutation table. + + :param int max_iters: The maximum number of iterations over the tree + sequence. Defaults to 10. + + :return: A new tree sequence with unary nodes extended. + :rtype: tskit.TreeSequence + """ + max_iter = int(max_iter) + ll_ts = self._ll_tree_sequence.extend_edges(max_iter) + return TreeSequence(ll_ts) + def subset( self, nodes,