diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c index 37e9d09f11..dcf4d2a69d 100644 --- a/c/tests/test_trees.c +++ b/c/tests/test_trees.c @@ -8414,7 +8414,7 @@ test_split_edges_errors(void) } static void -test_extend_edges_simple(void) +test_extend_haplotypes_simple(void) { int ret; tsk_treeseq_t ts, ets; @@ -8429,16 +8429,16 @@ test_extend_edges_simple(void) "1 1 1 -1 0.5\n"; tsk_treeseq_from_text(&ts, 10, nodes, edges, NULL, sites, mutations, NULL, NULL, 0); - ret = tsk_treeseq_extend_edges(&ts, 10, 0, &ets); + ret = tsk_treeseq_extend_haplotypes(&ts, 10, 0, &ets); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_TRUE(tsk_table_collection_equals(ts.tables, ets.tables, 0)); + CU_ASSERT_TRUE_FATAL(tsk_table_collection_equals(ts.tables, ets.tables, 0)); tsk_treeseq_free(&ts); tsk_treeseq_free(&ets); } static void -test_extend_edges_errors(void) +test_extend_haplotypes_errors(void) { int ret; tsk_treeseq_t ts, ets; @@ -8458,19 +8458,19 @@ test_extend_edges_errors(void) "0 10 0 1 0 1.5\n"; tsk_treeseq_from_text(&ts, 10, nodes, edges, NULL, sites, mutations, NULL, NULL, 0); - ret = tsk_treeseq_extend_edges(&ts, -2, 0, &ets); + ret = tsk_treeseq_extend_haplotypes(&ts, -2, 0, &ets); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_EXTEND_EDGES_BAD_MAXITER); tsk_treeseq_free(&ts); tsk_treeseq_from_text( &ts, 10, nodes, edges, migrations, sites, mutations, NULL, NULL, 0); - ret = tsk_treeseq_extend_edges(&ts, 10, 0, &ets); + ret = tsk_treeseq_extend_haplotypes(&ts, 10, 0, &ets); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MIGRATIONS_NOT_SUPPORTED); tsk_treeseq_free(&ts); tsk_treeseq_from_text( &ts, 10, nodes, edges, NULL, sites, mutations_no_time, NULL, NULL, 0); - ret = tsk_treeseq_extend_edges(&ts, 10, 0, &ets); + ret = tsk_treeseq_extend_haplotypes(&ts, 10, 0, &ets); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_DISALLOWED_UNKNOWN_MUTATION_TIME); tsk_treeseq_free(&ts); @@ -8503,17 +8503,21 @@ assert_equal_except_edges_and_mutation_nodes( } static void -test_extend_edges(void) +test_extend_haplotypes(void) { - int ret, max_iter; + int ret = 0; + int max_iter = 10; tsk_treeseq_t ts, ets; - /* 7 and 8 should be extended to the whole sequence + FILE *tmp = fopen(_tmp_file_name, "w"); + + /* 7 and 8 should be extended to the whole sequence; + * also 5 to the second tree (where x's are) 6 6 6 6 +-+-+ +-+-+ +-+-+ +-+-+ - | | 7 | | 8 | | + | | 7 x x 8 x x | | ++-+ | | +-++ | | - 4 5 4 | | 4 | 5 4 5 + 4 5 4 | x 4 | 5 4 5 +++ +++ +++ | | | | +++ +++ +++ 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 */ @@ -8550,26 +8554,144 @@ test_extend_edges(void) "9.0 0\n"; const char *mutations = "0 4 1 -1 2.5\n" "0 4 2 0 1.5\n" - "1 5 1 -1 2.5\n" - "1 5 2 2 1.5\n"; + "1 6 3 -1 3.5\n" + "1 5 1 2 2.5\n" + "1 5 2 3 1.5\n"; tsk_treeseq_from_text(&ts, 10, nodes, edges, NULL, sites, mutations, NULL, NULL, 0); for (max_iter = 1; max_iter < 10; max_iter++) { - ret = tsk_treeseq_extend_edges(&ts, max_iter, 0, &ets); + ret = tsk_treeseq_extend_haplotypes(&ts, max_iter, 0, &ets); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_equal_except_edges_and_mutation_nodes(&ts, &ets); CU_ASSERT_TRUE(ets.tables->edges.num_rows >= 12); tsk_treeseq_free(&ets); } - ret = tsk_treeseq_extend_edges(&ts, 10, 0, &ets); + ret = tsk_treeseq_extend_haplotypes(&ts, max_iter, 0, &ets); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_EQUAL_FATAL(ets.tables->nodes.num_rows, 9); CU_ASSERT_EQUAL_FATAL(ets.tables->edges.num_rows, 12); + assert_equal_except_edges_and_mutation_nodes(&ts, &ets); + tsk_treeseq_free(&ets); + + tsk_set_debug_stream(tmp); + ret = tsk_treeseq_extend_haplotypes(&ts, max_iter, TSK_DEBUG, &ets); + tsk_set_debug_stream(stdout); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(ftell(tmp) > 0); + tsk_treeseq_free(&ets); + + fclose(tmp); + tsk_treeseq_free(&ts); +} + +static void +test_extend_haplotypes_conflicting_times(void) +{ + int ret; + int max_iter = 10; + tsk_treeseq_t ts, ets; + /* + 3.00┊ 3 ┊ 4 ┊ + ┊ ┃ ┊ ┃ ┊ + 2.00┊ ┃ ┊ 2 ┊ + ┊ ┃ ┊ ┃ ┊ + 1.00┊ 1 ┊ ┃ ┊ + ┊ ┃ ┊ ┃ ┊ + 0.00┊ 0 ┊ 0 ┊ + 0 2 4 + */ + + const char *nodes = "1 0.0 -1 -1\n" + "0 1.0 -1 -1\n" + "0 2.0 -1 -1\n" + "0 3.0 -1 -1\n" + "0 3.0 -1 -1\n"; + // l, r, p, c + const char *edges = "0.0 2.0 1 0\n" + "2.0 4.0 2 0\n" + "0.0 2.0 3 1\n" + "2.0 4.0 4 2\n"; + + tsk_treeseq_from_text(&ts, 4, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ts.tables->edges.num_rows, 4); + + ret = tsk_treeseq_extend_haplotypes(&ts, max_iter, 0, &ets); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_table_collection_equals(ts.tables, ets.tables, 0)); + tsk_treeseq_free(&ets); + + tsk_treeseq_free(&ts); +} + +static void +test_extend_haplotypes_new_edge(void) +{ + int ret; + int max_iter = 10; + tsk_treeseq_t ts, ets, ref_ts; + /* This is an example where new edges are added + * on both forwards and back passes + 4.00┊ ┊ 4 ┊ 4 ┊ 4 ┊ + ┊ ┊ ┃ ┊ ┃ ┊ ┃ ┊ + 3.00┊ 2 ┊ ┃ ┊ 2 ┊ 2 ┊ + ┊ ┃ ┊ ┃ ┊ ┃ ┊ ┃ ┊ + 2.00┊ ┃ ┊ 3 ┊ ┃ ┊ 3 ┊ + ┊ ┃ ┊ ┃ ┊ ┃ ┊ ┃ ┊ + 1.00┊ 1 ┊ ┃ ┊ ┃ ┊ ┃ ┊ + ┊ ┃ ┊ ┃ ┊ ┃ ┊ ┃ ┊ + 0.00┊ 0 ┊ 0 ┊ 0 ┊ 0 ┊ + 0 2 4 6 8 + */ + + const char *nodes = "1 0.0 -1 -1\n" + "0 1.0 -1 -1\n" + "0 3.0 -1 -1\n" + "0 2.0 -1 -1\n" + "0 4.0 -1 -1\n"; + // l, r, p, c + const char *edges = "0.0 2.0 1 0\n" + "2.0 4.0 3 0\n" + "6.0 8.0 3 0\n" + "4.0 5.0 2 0\n" + "5.0 6.0 2 0\n" + "0.0 2.0 2 1\n" + "6.0 7.0 2 3\n" + "7.0 8.0 2 3\n" + "4.0 8.0 4 2\n" + "2.0 4.0 4 3\n"; + const char *ext_edges = "0.0 8.0 1 0\n" + "0.0 8.0 3 1\n" + "0.0 8.0 2 3\n" + "2.0 8.0 4 2\n"; + const char *sites = "3.0 0\n"; + // s, n , ds, t + const char *mutations = "0 4 5 -1 4.5\n" + "0 3 4 0 3.5\n" + "0 3 3 1 2.5\n" + "0 0 2 2 1.5\n" + "0 0 1 3 0.5\n"; + const char *ext_mutations = "0 4 5 -1 4.5\n" + "0 2 4 0 3.5\n" + "0 3 3 1 2.5\n" + "0 1 2 2 1.5\n" + "0 0 1 3 0.5\n"; + + tsk_treeseq_from_text(&ts, 8, nodes, edges, NULL, sites, mutations, NULL, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ts.tables->edges.num_rows, 10); + tsk_treeseq_from_text( + &ref_ts, 8, nodes, ext_edges, NULL, sites, ext_mutations, NULL, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ref_ts.tables->edges.num_rows, 4); + + ret = tsk_treeseq_extend_haplotypes(&ts, max_iter, 0, &ets); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_equal_except_edges_and_mutation_nodes(&ts, &ets); + CU_ASSERT_TRUE(tsk_table_collection_equals(ets.tables, ref_ts.tables, 0)); tsk_treeseq_free(&ets); tsk_treeseq_free(&ts); + tsk_treeseq_free(&ref_ts); } static void @@ -8793,9 +8915,12 @@ main(int argc, char **argv) { "test_split_edges_no_populations", test_split_edges_no_populations }, { "test_split_edges_populations", test_split_edges_populations }, { "test_split_edges_errors", test_split_edges_errors }, - { "test_extend_edges_simple", test_extend_edges_simple }, - { "test_extend_edges_errors", test_extend_edges_errors }, - { "test_extend_edges", test_extend_edges }, + { "test_extend_haplotypes_simple", test_extend_haplotypes_simple }, + { "test_extend_haplotypes_errors", test_extend_haplotypes_errors }, + { "test_extend_haplotypes", test_extend_haplotypes }, + { "test_extend_haplotypes_new_edge", test_extend_haplotypes_new_edge }, + { "test_extend_haplotypes_conflicting_times", + test_extend_haplotypes_conflicting_times }, { "test_init_take_ownership_no_edge_metadata", test_init_take_ownership_no_edge_metadata }, { NULL, NULL }, diff --git a/c/tests/testlib.c b/c/tests/testlib.c index 8dca6d0720..87537ff052 100644 --- a/c/tests/testlib.c +++ b/c/tests/testlib.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2023 Tskit Developers + * Copyright (c) 2019-2024 Tskit Developers * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -784,7 +784,9 @@ tsk_treeseq_from_text(tsk_treeseq_t *ts, double sequence_length, const char *nod ret = tsk_treeseq_init(ts, &tables, TSK_TS_INIT_BUILD_INDEXES); /* tsk_treeseq_print_state(ts, stdout); */ - /* printf("ret = %s\n", tsk_strerror(ret)); */ + if (ret != 0) { + printf("\nret = %s\n", tsk_strerror(ret)); + } CU_ASSERT_EQUAL_FATAL(ret, 0); tsk_table_collection_free(&tables); } diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 0905992e5c..d214f1dce0 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -8521,32 +8521,44 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_sample_s } /* ======================================================== * - * Extend edges + * Extend haplotypes * ======================================================== */ typedef struct _edge_list_t { tsk_id_t edge; // the `extended` flags records whether we have decided to extend // this entry to the current tree? - bool extended; + int extended; struct _edge_list_t *next; } edge_list_t; -static int -extend_edges_append_entry( - edge_list_t **head, edge_list_t **tail, tsk_blkalloc_t *heap, tsk_id_t edge) -{ - int ret = 0; - edge_list_t *x = NULL; - - x = tsk_blkalloc_get(heap, sizeof(*x)); - if (x == NULL) { - ret = TSK_ERR_NO_MEMORY; - goto out; +static void +edge_list_print(edge_list_t **head, tsk_edge_table_t *edges, FILE *out) +{ + int n = 0; + edge_list_t *px; + fprintf(out, "Edge list:\n"); + for (px = *head; px != NULL; px = px->next) { + fprintf(out, " %d: %d (%d); ", n, (int) px->edge, px->extended); + if (px->edge >= 0 && edges != NULL) { + fprintf(out, "%d->%d on [%.1f, %.1f)", (int) edges->child[px->edge], + (int) edges->parent[px->edge], edges->left[px->edge], + edges->right[px->edge]); + } else { + fprintf(out, "(null)"); + } + fprintf(out, "\n"); + n += 1; } + fprintf(out, "length = %d\n", n); +} +static void +edge_list_append_entry( + edge_list_t **head, edge_list_t **tail, edge_list_t *x, tsk_id_t edge, int extended) +{ x->edge = edge; - x->extended = false; + x->extended = extended; x->next = NULL; if (*tail == NULL) { @@ -8555,8 +8567,6 @@ extend_edges_append_entry( (*tail)->next = x; } *tail = x; -out: - return ret; } static void @@ -8565,16 +8575,16 @@ remove_unextended(edge_list_t **head, edge_list_t **tail) edge_list_t *px, *x; px = *head; - while (px != NULL && !px->extended) { + while (px != NULL && px->extended == 0) { px = px->next; } *head = px; if (px != NULL) { - px->extended = false; + px->extended = 0; x = px->next; while (x != NULL) { - if (x->extended) { - x->extended = false; + if (x->extended > 0) { + x->extended = 0; px->next = x; px = x; } @@ -8585,75 +8595,479 @@ remove_unextended(edge_list_t **head, edge_list_t **tail) *tail = px; } +static void +edge_list_set_extended(edge_list_t **head, tsk_id_t edge_id) +{ + // finds the entry with edge 'edge_id' + // and sets its 'extended' flag to 1 + edge_list_t *px; + px = *head; + tsk_bug_assert(px != NULL); + while (px->edge != edge_id) { + px = px->next; + tsk_bug_assert(px != NULL); + } + tsk_bug_assert(px->edge == edge_id); + px->extended = 1; +} + static int -tsk_treeseq_extend_edges_iter( - const tsk_treeseq_t *self, int direction, tsk_edge_table_t *edges) +tsk_treeseq_slide_mutation_nodes_up( + const tsk_treeseq_t *self, tsk_mutation_table_t *mutations) { - // Note: this modifies the edge table, but it does this by (a) removing - // some edges, and (b) extending left/right endpoints of others, - // while keeping order the same, and so this maintains sortedness - // (so, there is no need to sort afterwards). int ret = 0; - tsk_id_t tj; - tsk_id_t e, e_out, e_in; - tsk_id_t c, p, p_in; - tsk_blkalloc_t edge_list_heap; - double *near_side, *far_side; - edge_list_t *edges_in_head, *edges_in_tail; - edge_list_t *edges_out_head, *edges_out_tail; - edge_list_t *ex_out, *ex_in; - double there, left, right; - bool forwards = (direction == TSK_DIR_FORWARD); - tsk_tree_position_t tree_pos; - bool valid; - const tsk_table_collection_t *tables = self->tables; - const tsk_size_t num_nodes = tables->nodes.num_rows; - const tsk_size_t num_edges = tables->edges.num_rows; - tsk_id_t *degree = tsk_calloc(num_nodes, sizeof(*degree)); - tsk_id_t *out_parent = tsk_malloc(num_nodes * sizeof(*out_parent)); - tsk_bool_t *keep = tsk_calloc(num_edges, sizeof(*keep)); - bool *not_sample = tsk_malloc(num_nodes * sizeof(*not_sample)); - - tsk_memset(&edge_list_heap, 0, sizeof(edge_list_heap)); - tsk_memset(&tree_pos, 0, sizeof(tree_pos)); + double t; + tsk_id_t c, p, next_mut; + const tsk_size_t num_nodes = self->tables->nodes.num_rows; + const double *sites_position = self->tables->sites.position; + const double *nodes_time = self->tables->nodes.time; + tsk_tree_t tree; - if (keep == NULL || out_parent == NULL || degree == NULL || not_sample == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_tree_init(&tree, self, TSK_NO_SAMPLE_COUNTS); + if (ret != 0) { goto out; } - tsk_memset(out_parent, 0xff, num_nodes * sizeof(*out_parent)); - ret = tsk_blkalloc_init(&edge_list_heap, 8192); + next_mut = 0; + for (ret = tsk_tree_first(&tree); ret == TSK_TREE_OK; ret = tsk_tree_next(&tree)) { + while (next_mut < (tsk_id_t) mutations->num_rows + && sites_position[mutations->site[next_mut]] < tree.interval.right) { + t = mutations->time[next_mut]; + if (tsk_is_unknown_time(t)) { + ret = TSK_ERR_DISALLOWED_UNKNOWN_MUTATION_TIME; + goto out; + } + c = mutations->node[next_mut]; + tsk_bug_assert(c < (tsk_id_t) num_nodes); + p = tree.parent[c]; + while (p != TSK_NULL && nodes_time[p] <= t) { + c = p; + p = tree.parent[c]; + } + tsk_bug_assert(nodes_time[c] <= t); + mutations->node[next_mut] = c; + next_mut++; + } + } if (ret != 0) { goto out; } - ret = tsk_tree_position_init(&tree_pos, self, 0); + +out: + tsk_tree_free(&tree); + + return ret; +} + +typedef struct { + const tsk_treeseq_t *ts; + tsk_edge_table_t *edges; + int direction; + tsk_id_t *last_degree, *next_degree; + tsk_id_t *last_nodes_edge, *next_nodes_edge; + tsk_id_t *parent_out, *parent_in; + bool *not_sample; + double *near_side, *far_side; + edge_list_t *edges_out_head, *edges_out_tail; + edge_list_t *edges_in_head, *edges_in_tail; + tsk_blkalloc_t edge_list_heap; +} haplotype_extender_t; + +static int +haplotype_extender_init(haplotype_extender_t *self, const tsk_treeseq_t *ts, + int direction, tsk_edge_table_t *edges) +{ + int ret = 0; + tsk_id_t tj; + tsk_size_t num_nodes = tsk_treeseq_get_num_nodes(ts); + + tsk_memset(self, 0, sizeof(haplotype_extender_t)); + + self->ts = ts; + self->edges = edges; + ret = tsk_edge_table_copy(&ts->tables->edges, self->edges, TSK_NO_INIT); if (ret != 0) { goto out; } - ret = tsk_edge_table_copy(&tables->edges, edges, TSK_NO_INIT); + + self->direction = direction; + if (direction == TSK_DIR_FORWARD) { + self->near_side = self->edges->left; + self->far_side = self->edges->right; + } else { + self->near_side = self->edges->right; + self->far_side = self->edges->left; + } + + self->edges_in_head = NULL; + self->edges_in_tail = NULL; + self->edges_out_head = NULL; + self->edges_out_tail = NULL; + + ret = tsk_blkalloc_init(&self->edge_list_heap, 8192); if (ret != 0) { goto out; } - for (tj = 0; tj < (tsk_id_t) tables->nodes.num_rows; tj++) { - not_sample[tj] = ((tables->nodes.flags[tj] & TSK_NODE_IS_SAMPLE) == 0); + self->last_degree = tsk_calloc(num_nodes, sizeof(*self->last_degree)); + self->next_degree = tsk_calloc(num_nodes, sizeof(*self->next_degree)); + self->last_nodes_edge = tsk_malloc(num_nodes * sizeof(*self->last_nodes_edge)); + self->next_nodes_edge = tsk_malloc(num_nodes * sizeof(*self->next_nodes_edge)); + self->parent_out = tsk_malloc(num_nodes * sizeof(*self->parent_out)); + self->parent_in = tsk_malloc(num_nodes * sizeof(*self->parent_in)); + self->not_sample = tsk_malloc(num_nodes * sizeof(*self->not_sample)); + + if (self->last_degree == NULL || self->next_degree == NULL + || self->last_nodes_edge == NULL || self->next_nodes_edge == NULL + || self->parent_out == NULL || self->parent_in == NULL + || self->not_sample == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + tsk_memset(self->last_nodes_edge, 0xff, num_nodes * sizeof(*self->last_nodes_edge)); + tsk_memset(self->next_nodes_edge, 0xff, num_nodes * sizeof(*self->next_nodes_edge)); + tsk_memset(self->parent_out, 0xff, num_nodes * sizeof(*self->parent_out)); + tsk_memset(self->parent_in, 0xff, num_nodes * sizeof(*self->parent_in)); + + for (tj = 0; tj < (tsk_id_t) num_nodes; tj++) { + self->not_sample[tj] = ((ts->tables->nodes.flags[tj] & TSK_NODE_IS_SAMPLE) == 0); + } + +out: + return ret; +} + +static void +haplotype_extender_print_state(haplotype_extender_t *self, FILE *out) +{ + fprintf(out, "\n======= haplotype extender ===========\n"); + fprintf(out, "parent in:\n"); + for (int j = 0; j < (int) self->ts->tables->nodes.num_rows; j++) { + fprintf(out, " %d: %d\n", j, (int) self->parent_in[j]); + } + fprintf(out, "parent out:\n"); + for (int j = 0; j < (int) self->ts->tables->nodes.num_rows; j++) { + fprintf(out, " %d: %d\n", j, (int) self->parent_out[j]); + } + fprintf(out, "last nodes edge:\n"); + for (int j = 0; j < (int) self->ts->tables->nodes.num_rows; j++) { + tsk_id_t ej = self->last_nodes_edge[j]; + fprintf(out, " %d: %d, ", j, (int) ej); + if (self->last_nodes_edge[j] != TSK_NULL) { + fprintf(out, "(%d->%d, %.1f-%.1f)", (int) self->edges->child[ej], + (int) self->edges->parent[ej], self->edges->left[ej], + self->edges->right[ej]); + } else { + fprintf(out, "(null);"); + } + fprintf(out, "\n"); + } + fprintf(out, "next nodes edge:\n"); + for (int j = 0; j < (int) self->ts->tables->nodes.num_rows; j++) { + tsk_id_t ej = self->next_nodes_edge[j]; + fprintf(out, " %d: %d, ", j, (int) ej); + if (self->next_nodes_edge[j] != TSK_NULL) { + fprintf(out, "(%d->%d, %.1f-%.1f)", (int) self->edges->child[ej], + (int) self->edges->parent[ej], self->edges->left[ej], + self->edges->right[ej]); + } else { + fprintf(out, "(null);"); + } + fprintf(out, "\n"); + } + fprintf(out, "edges out:\n"); + edge_list_print(&self->edges_out_head, self->edges, out); + fprintf(out, "edges in:\n"); + edge_list_print(&self->edges_in_head, self->edges, out); +} + +static int +haplotype_extender_free(haplotype_extender_t *self) +{ + tsk_blkalloc_free(&self->edge_list_heap); + tsk_safe_free(self->last_degree); + tsk_safe_free(self->next_degree); + tsk_safe_free(self->last_nodes_edge); + tsk_safe_free(self->next_nodes_edge); + tsk_safe_free(self->parent_out); + tsk_safe_free(self->parent_in); + tsk_safe_free(self->not_sample); + return 0; +} + +static int +haplotype_extender_next_tree(haplotype_extender_t *self, tsk_tree_position_t *tree_pos) +{ + int ret = 0; + tsk_id_t tj, e; + edge_list_t *ex_out, *ex_in; + edge_list_t *new_ex; + const tsk_id_t *edges_child = self->edges->child; + const tsk_id_t *edges_parent = self->edges->parent; + + for (ex_out = self->edges_out_head; ex_out != NULL; ex_out = ex_out->next) { + e = ex_out->edge; + self->parent_out[edges_child[e]] = TSK_NULL; + // note we only adjust near_side of edges_in, not edges_out, + // so no need to check for zero-length edges + if (ex_out->extended > 1) { + // this is needed to catch newly-created edges + self->last_nodes_edge[edges_child[e]] = e; + self->last_degree[edges_child[e]] += 1; + self->last_degree[edges_parent[e]] += 1; + } else if (ex_out->extended == 0) { + self->last_nodes_edge[edges_child[e]] = TSK_NULL; + self->last_degree[edges_child[e]] -= 1; + self->last_degree[edges_parent[e]] -= 1; + } + } + remove_unextended(&self->edges_out_head, &self->edges_out_tail); + for (ex_in = self->edges_in_head; ex_in != NULL; ex_in = ex_in->next) { + e = ex_in->edge; + self->parent_in[edges_child[e]] = TSK_NULL; + if (ex_in->extended == 0 && self->near_side[e] != self->far_side[e]) { + self->last_nodes_edge[edges_child[e]] = e; + self->last_degree[edges_child[e]] += 1; + self->last_degree[edges_parent[e]] += 1; + } + } + remove_unextended(&self->edges_in_head, &self->edges_in_tail); + + // done cleanup from last tree transition; + // now we set the state up for this tree transition + for (tj = tree_pos->out.start; tj != tree_pos->out.stop; tj += self->direction) { + e = tree_pos->out.order[tj]; + if (self->near_side[e] != self->far_side[e]) { + new_ex = tsk_blkalloc_get(&self->edge_list_heap, sizeof(*new_ex)); + if (new_ex == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + edge_list_append_entry( + &self->edges_out_head, &self->edges_out_tail, new_ex, e, 0); + } + } + for (ex_out = self->edges_out_head; ex_out != NULL; ex_out = ex_out->next) { + e = ex_out->edge; + self->parent_out[edges_child[e]] = edges_parent[e]; + self->next_nodes_edge[edges_child[e]] = TSK_NULL; + self->next_degree[edges_child[e]] -= 1; + self->next_degree[edges_parent[e]] -= 1; + } + + for (tj = tree_pos->in.start; tj != tree_pos->in.stop; tj += self->direction) { + e = tree_pos->in.order[tj]; + // add edge to pending_in + new_ex = tsk_blkalloc_get(&self->edge_list_heap, sizeof(*new_ex)); + if (new_ex == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + edge_list_append_entry(&self->edges_in_head, &self->edges_in_tail, new_ex, e, 0); + } + for (ex_in = self->edges_in_head; ex_in != NULL; ex_in = ex_in->next) { + e = ex_in->edge; + self->parent_in[edges_child[e]] = edges_parent[e]; + self->next_nodes_edge[edges_child[e]] = e; + self->next_degree[edges_child[e]] += 1; + self->next_degree[edges_parent[e]] += 1; } - if (forwards) { - near_side = edges->left; - far_side = edges->right; +out: + return ret; +} + +static int +haplotype_extender_add_or_extend_edge(haplotype_extender_t *self, tsk_id_t new_parent, + tsk_id_t child, double left, double right) +{ + int ret = 0; + double there; + tsk_id_t old_edge, e_out, old_parent; + edge_list_t *ex_in; + edge_list_t *new_ex = NULL; + tsk_id_t e_in; + + there = (self->direction == TSK_DIR_FORWARD) ? right : left; + old_edge = self->next_nodes_edge[child]; + if (old_edge != TSK_NULL) { + old_parent = self->edges->parent[old_edge]; } else { - near_side = edges->right; - far_side = edges->left; + old_parent = TSK_NULL; + } + if (new_parent != old_parent) { + if (self->parent_out[child] == new_parent) { + // if our new edge is in edges_out, it should be extended + e_out = self->last_nodes_edge[child]; + self->far_side[e_out] = there; + edge_list_set_extended(&self->edges_out_head, e_out); + } else { + e_out = tsk_edge_table_add_row( + self->edges, left, right, new_parent, child, NULL, 0); + if (e_out < 0) { + ret = (int) e_out; + goto out; + } + /* pointers to left/right might have changed! */ + if (self->direction == TSK_DIR_FORWARD) { + self->near_side = self->edges->left; + self->far_side = self->edges->right; + } else { + self->near_side = self->edges->right; + self->far_side = self->edges->left; + } + new_ex = tsk_blkalloc_get(&self->edge_list_heap, sizeof(*new_ex)); + if (new_ex == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + edge_list_append_entry( + &self->edges_out_head, &self->edges_out_tail, new_ex, e_out, 2); + } + self->next_nodes_edge[child] = e_out; + self->next_degree[child] += 1; + self->next_degree[new_parent] += 1; + self->parent_out[child] = TSK_NULL; + if (old_edge != TSK_NULL) { + for (ex_in = self->edges_in_head; ex_in != NULL; ex_in = ex_in->next) { + e_in = ex_in->edge; + if (e_in == old_edge) { + self->near_side[e_in] = there; + if (self->far_side[e_in] != there) { + ex_in->extended = 1; + } + self->next_degree[child] -= 1; + self->next_degree[self->parent_in[child]] -= 1; + self->parent_in[child] = TSK_NULL; + } + } + } + } +out: + return ret; +} + +static float +haplotype_extender_mergeable(haplotype_extender_t *self, tsk_id_t c) +{ + // returns the number of new edges needed + // if the paths in parent_in and parent_out + // up through nodes that aren't in the other tree + // end at the same place and don't have conflicting times; + // otherwise, return infinity + tsk_id_t p_in, p_out, child; + float num_new_edges; // needs to be float so we can have infinity + int num_extended; + double t_in, t_out; + bool climb_in, climb_out; + const double *nodes_time = self->ts->tables->nodes.time; + + p_out = self->parent_out[c]; + p_in = self->parent_in[c]; + t_out = (p_out == TSK_NULL) ? INFINITY : nodes_time[p_out]; + t_in = (p_in == TSK_NULL) ? INFINITY : nodes_time[p_in]; + child = c; + num_new_edges = 0; + num_extended = 0; + while (true) { + climb_in = (p_in != TSK_NULL && self->last_degree[p_in] == 0 + && self->not_sample[p_in] && t_in < t_out); + climb_out = (p_out != TSK_NULL && self->next_degree[p_out] == 0 + && self->not_sample[p_out] && t_out < t_in); + if (climb_in) { + if (self->parent_in[child] != p_in) { + num_new_edges += 1; + } + child = p_in; + p_in = self->parent_in[p_in]; + t_in = (p_in == TSK_NULL) ? INFINITY : nodes_time[p_in]; + } else if (climb_out) { + if (self->parent_out[child] != p_out) { + num_new_edges += 1; + } + child = p_out; + p_out = self->parent_out[p_out]; + t_out = (p_out == TSK_NULL) ? INFINITY : nodes_time[p_out]; + num_extended += 1; + } else { + break; + } } - edges_in_head = NULL; - edges_in_tail = NULL; - edges_out_head = NULL; - edges_out_tail = NULL; - e_out = 0; // only to avoid an 'maybe uninitialized' compile warning + if ((num_extended == 0) || (p_in != p_out) || (p_in == TSK_NULL)) { + num_new_edges = INFINITY; + } + return num_new_edges; +} + +static int +haplotype_extender_merge_paths( + haplotype_extender_t *self, tsk_id_t c, double left, double right) +{ + int ret = 0; + tsk_id_t p_in, p_out, child; + double t_in, t_out; + bool climb_in, climb_out; + const double *nodes_time = self->ts->tables->nodes.time; + + p_out = self->parent_out[c]; + p_in = self->parent_in[c]; + t_out = nodes_time[p_out]; + t_in = nodes_time[p_in]; + child = c; + while (true) { + climb_in = (p_in != TSK_NULL && self->last_degree[p_in] == 0 + && self->not_sample[p_in] && t_in < t_out); + climb_out = (p_out != TSK_NULL && self->next_degree[p_out] == 0 + && self->not_sample[p_out] && t_out < t_in); + if (climb_in) { + ret = haplotype_extender_add_or_extend_edge(self, p_in, child, left, right); + if (ret != 0) { + goto out; + } + child = p_in; + p_in = self->parent_in[p_in]; + t_in = (p_in == TSK_NULL) ? INFINITY : nodes_time[p_in]; + } else if (climb_out) { + ret = haplotype_extender_add_or_extend_edge(self, p_out, child, left, right); + if (ret != 0) { + goto out; + } + child = p_out; + p_out = self->parent_out[p_out]; + t_out = (p_out == TSK_NULL) ? INFINITY : nodes_time[p_out]; + } else { + break; + } + } + tsk_bug_assert(p_out == p_in); + ret = haplotype_extender_add_or_extend_edge(self, p_out, child, left, right); + if (ret != 0) { + goto out; + } +out: + return ret; +} + +static int +haplotype_extender_extend_paths(haplotype_extender_t *self) +{ + int ret = 0; + bool valid; + double left, right; + float ne, max_new_edges, next_max_new_edges; + tsk_tree_position_t tree_pos; + edge_list_t *ex_in; + tsk_id_t e_in, c, e; + tsk_size_t num_edges; + tsk_bool_t *keep = NULL; - if (forwards) { + tsk_memset(&tree_pos, 0, sizeof(tree_pos)); + ret = tsk_tree_position_init(&tree_pos, self->ts, 0); + if (ret != 0) { + goto out; + } + + if (self->direction == TSK_DIR_FORWARD) { valid = tsk_tree_position_next(&tree_pos); } else { valid = tsk_tree_position_prev(&tree_pos); @@ -8662,177 +9076,100 @@ tsk_treeseq_extend_edges_iter( while (valid) { left = tree_pos.interval.left; right = tree_pos.interval.right; - there = forwards ? right : left; - - // remove entries that aren't being extended/postponed - // and update out_parent - for (ex_out = edges_out_head; ex_out != NULL; ex_out = ex_out->next) { - e = ex_out->edge; - out_parent[edges->child[e]] = TSK_NULL; - } - remove_unextended(&edges_in_head, &edges_in_tail); - remove_unextended(&edges_out_head, &edges_out_tail); - for (ex_out = edges_out_head; ex_out != NULL; ex_out = ex_out->next) { - e = ex_out->edge; - out_parent[edges->child[e]] = edges->parent[e]; - } - - for (tj = tree_pos.out.start; tj != tree_pos.out.stop; tj += direction) { - e = tree_pos.out.order[tj]; - if (out_parent[edges->child[e]] == TSK_NULL) { - // add edge to pending_out - ret = extend_edges_append_entry( - &edges_out_head, &edges_out_tail, &edge_list_heap, e); - if (ret != 0) { - ret = TSK_ERR_NO_MEMORY; - goto out; - } - out_parent[edges->child[e]] = edges->parent[e]; - } - } - for (tj = tree_pos.in.start; tj != tree_pos.in.stop; tj += direction) { - e = tree_pos.in.order[tj]; - // add edge to pending_in - ret = extend_edges_append_entry( - &edges_in_head, &edges_in_tail, &edge_list_heap, e); - if (ret != 0) { - ret = TSK_ERR_NO_MEMORY; - goto out; - } + ret = haplotype_extender_next_tree(self, &tree_pos); + if (ret != 0) { + goto out; } - for (ex_out = edges_out_head; ex_out != NULL; ex_out = ex_out->next) { - e_out = ex_out->edge; - degree[edges->parent[e_out]] -= 1; - degree[edges->child[e_out]] -= 1; - tsk_bug_assert(out_parent[edges->child[e_out]] == edges->parent[e_out]); - } - for (ex_in = edges_in_head; ex_in != NULL; ex_in = ex_in->next) { - e_in = ex_in->edge; - degree[edges->parent[e_in]] += 1; - degree[edges->child[e_in]] += 1; - } - - for (ex_in = edges_in_head; ex_in != NULL; ex_in = ex_in->next) { - e_in = ex_in->edge; - // check whether the parent-child relationship exists in the - // sub-forest of edges to be removed: - // out_parent[p] != -1 only when it is the bottom of an edge to be - // removed, and degree[p] == 0 only if it is not in the new tree - c = edges->child[e_in]; - p = out_parent[c]; - p_in = edges->parent[e_in]; - while ((p != TSK_NULL) && (degree[p] == 0) && (p != p_in) && not_sample[p]) { - p = out_parent[p]; - } - if (p == p_in) { - // we can extend! - // But, we might have passed the interval that a - // postponed edge in covers, in which case - // we should skip postponing the edge in - if (far_side[e_in] != there) { - ex_in->extended = true; - } - near_side[e_in] = there; - while (c != p) { - for (ex_out = edges_out_head; ex_out != NULL; - ex_out = ex_out->next) { - e_out = ex_out->edge; - if (edges->child[e_out] == c) { - break; + max_new_edges = 0; + next_max_new_edges = INFINITY; + while (max_new_edges < INFINITY) { + for (ex_in = self->edges_in_head; ex_in != NULL; ex_in = ex_in->next) { + e_in = ex_in->edge; + c = self->edges->child[e_in]; + if (self->last_degree[c] > 0) { + ne = haplotype_extender_mergeable(self, c); + if (ne <= max_new_edges) { + ret = haplotype_extender_merge_paths(self, c, left, right); + if (ret != 0) { + goto out; } + } else { + next_max_new_edges = TSK_MIN(ne, next_max_new_edges); } - tsk_bug_assert(edges->child[e_out] == c); - ex_out->extended = true; - far_side[e_out] = there; - // amend degree: the intermediate - // nodes have 2 edges instead of 0 - tsk_bug_assert(degree[c] == 0 || c == edges->child[e_in]); - if (degree[c] == 0) { - degree[c] = 2; - } - c = out_parent[c]; } } + max_new_edges = next_max_new_edges; + next_max_new_edges = INFINITY; } - if (forwards) { + if (self->direction == TSK_DIR_FORWARD) { valid = tsk_tree_position_next(&tree_pos); } else { valid = tsk_tree_position_prev(&tree_pos); } } - + /* Get rid of adjacent, identical edges */ + /* note: we need to calloc this here instead of at the start + * because we don't know how big it will need to be until now */ + num_edges = self->edges->num_rows; + keep = tsk_calloc(num_edges, sizeof(*keep)); + if (keep == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + for (e = 0; e < (tsk_id_t) num_edges - 1; e++) { + if (self->edges->parent[e] == self->edges->parent[e + 1] + && self->edges->child[e] == self->edges->child[e + 1] + && self->edges->right[e] == self->edges->left[e + 1]) { + self->edges->right[e] = self->edges->right[e + 1]; + self->edges->left[e + 1] = self->edges->right[e + 1]; + } + } for (e = 0; e < (tsk_id_t) num_edges; e++) { - keep[e] = edges->left[e] < edges->right[e]; + keep[e] = self->edges->left[e] < self->edges->right[e]; } - ret = tsk_edge_table_keep_rows(edges, keep, 0, NULL); + ret = tsk_edge_table_keep_rows(self->edges, keep, 0, NULL); out: - tsk_blkalloc_free(&edge_list_heap); tsk_tree_position_free(&tree_pos); - tsk_safe_free(degree); - tsk_safe_free(out_parent); tsk_safe_free(keep); - tsk_safe_free(not_sample); return ret; } static int -tsk_treeseq_slide_mutation_nodes_up( - const tsk_treeseq_t *self, tsk_mutation_table_t *mutations) +extend_haplotypes_iter(const tsk_treeseq_t *self, int direction, tsk_edge_table_t *edges, + tsk_flags_t options) { int ret = 0; - double t; - tsk_id_t c, p, next_mut; - const tsk_table_collection_t *tables = self->tables; - const tsk_size_t num_nodes = tables->nodes.num_rows; - double *sites_position = tables->sites.position; - double *nodes_time = tables->nodes.time; - tsk_tree_t tree; - - ret = tsk_tree_init(&tree, self, TSK_NO_SAMPLE_COUNTS); + haplotype_extender_t haplotype_extender; + tsk_memset(&haplotype_extender, 0, sizeof(haplotype_extender)); + ret = haplotype_extender_init(&haplotype_extender, self, direction, edges); if (ret != 0) { goto out; } - next_mut = 0; - for (ret = tsk_tree_first(&tree); ret == TSK_TREE_OK; ret = tsk_tree_next(&tree)) { - while (next_mut < (tsk_id_t) mutations->num_rows - && sites_position[mutations->site[next_mut]] < tree.interval.right) { - t = mutations->time[next_mut]; - if (tsk_is_unknown_time(t)) { - ret = TSK_ERR_DISALLOWED_UNKNOWN_MUTATION_TIME; - goto out; - } - c = mutations->node[next_mut]; - tsk_bug_assert(c < (tsk_id_t) num_nodes); - p = tree.parent[c]; - while (p != TSK_NULL && nodes_time[p] <= t) { - c = p; - p = tree.parent[c]; - } - tsk_bug_assert(nodes_time[c] <= t); - mutations->node[next_mut] = c; - next_mut++; - } - } + ret = haplotype_extender_extend_paths(&haplotype_extender); if (ret != 0) { goto out; } -out: - tsk_tree_free(&tree); + if (!!(options & TSK_DEBUG)) { + haplotype_extender_print_state(&haplotype_extender, tsk_get_debug_stream()); + } +out: + haplotype_extender_free(&haplotype_extender); return ret; } int TSK_WARN_UNUSED -tsk_treeseq_extend_edges(const tsk_treeseq_t *self, int max_iter, - tsk_flags_t TSK_UNUSED(options), tsk_treeseq_t *output) +tsk_treeseq_extend_haplotypes( + const tsk_treeseq_t *self, int max_iter, tsk_flags_t options, tsk_treeseq_t *output) { int ret = 0; tsk_table_collection_t tables; tsk_treeseq_t ts; int iter, j; tsk_size_t last_num_edges; + tsk_bookmark_t sort_start; const int direction[] = { TSK_DIR_FORWARD, TSK_DIR_REVERSE }; tsk_memset(&tables, 0, sizeof(tables)); @@ -8872,12 +9209,20 @@ tsk_treeseq_extend_edges(const tsk_treeseq_t *self, int max_iter, last_num_edges = tsk_treeseq_get_num_edges(&ts); for (iter = 0; iter < max_iter; iter++) { for (j = 0; j < 2; j++) { - ret = tsk_treeseq_extend_edges_iter(&ts, direction[j], &tables.edges); + ret = extend_haplotypes_iter(&ts, direction[j], &tables.edges, options); if (ret != 0) { goto out; } /* We're done with the current ts now */ tsk_treeseq_free(&ts); + /* no need to sort sites and mutations */ + memset(&sort_start, 0, sizeof(sort_start)); + sort_start.sites = tables.sites.num_rows; + sort_start.mutations = tables.mutations.num_rows; + ret = tsk_table_collection_sort(&tables, &sort_start, 0); + if (ret != 0) { + goto out; + } ret = tsk_treeseq_init(&ts, &tables, TSK_TS_INIT_BUILD_INDEXES); if (ret != 0) { goto out; diff --git a/c/tskit/trees.h b/c/tskit/trees.h index 667848415b..7169851cdc 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -917,18 +917,18 @@ int tsk_treeseq_simplify(const tsk_treeseq_t *self, const tsk_id_t *samples, tsk_id_t *node_map); /** -@brief Extends edges +@brief Extends haplotypes -Returns a modified tree sequence in which the span covered by ancestral nodes +Returns a new tree sequence in which the span covered by ancestral nodes is "extended" to regions of the genome according to the following rule: -If an ancestral segment corresponding to node `n` has parent `p` and -child `c` on some portion of the genome, and on an adjacent segment of -genome `p` is the immediate parent of `c`, then `n` is inserted into the -edge from `p` to `c`. This involves extending the span of the edges -from `p` to `n` and `n` to `c` and reducing the span of the edge from -`p` to `c`. Since the latter edge may be removed entirely, this process -reduces (or at least does not increase) the number of edges in the tree -sequence. The `node` of certain mutations may also be remapped; to do this +If an ancestral segment corresponding to node `n` has ancestor `p` and +descendant `c` on some portion of the genome, and on an adjacent segment of +genome `p` is still an ancestor of `c`, then `n` is inserted into the +path from `p` to `c`. For instance, if `p` is the parent of `n` and `n` +is the parent of `c`, then the span of the edges from `p` to `n` and +`n` to `c` are extended, and the span of the edge from `p` to `c` is +reduced. However, any edges whose child node is a sample are not +modified. The `node` of certain mutations may also be remapped; to do this unambiguously we need to know mutation times. If mutations times are unknown, use `tsk_table_collection_compute_mutation_times` first. @@ -939,8 +939,6 @@ The method works by iterating over the genome to look for edges that can be extended in this way; the maximum number of such iterations is controlled by ``max_iter``. -Since this may change which nodes are above - @rst **Options**: None currently defined. @@ -952,7 +950,7 @@ Since this may change which nodes are above @param output A pointer to an uninitialised tsk_treeseq_t object. @return Return 0 on success or a negative value on failure. */ -int tsk_treeseq_extend_edges( +int tsk_treeseq_extend_haplotypes( const tsk_treeseq_t *self, int max_iter, tsk_flags_t options, tsk_treeseq_t *output); /** @} */ diff --git a/docs/python-api.md b/docs/python-api.md index 20a2d541b4..12756ff3e3 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -268,7 +268,7 @@ which perform the same actions but modify the {class}`TableCollection` in place. TreeSequence.trim TreeSequence.split_edges TreeSequence.decapitate - TreeSequence.extend_edges + TreeSequence.extend_haplotypes ``` (sec_python_api_tree_sequences_ibd)= diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index f3bc0e1629..6cfc075bbf 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -26,9 +26,10 @@ **Features** -- Add ``TreeSequence.extend_edges`` method that extends ancestral haplotypes +- Add ``TreeSequence.extend_haplotypes`` method that extends ancestral haplotypes using recombination information, leading to unary nodes in many trees and - fewer edges. (:user:`petrelharp`, :user:`hfr1tz3`, :user:`avabamf`, :pr:`2651`) + fewer edges. (:user:`petrelharp`, :user:`hfr1tz3`, :user: `nspope`, + :user:`avabamf`, :pr:`2651`, :pr:`2938`) - Add ``Table.drop_metadata`` to make clearing metadata from tables easy. (:user:`jeromekelleher`, :pr:`2944`) diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index a37d8b0160..1c7ee57b0a 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -8944,7 +8944,7 @@ TreeSequence_mean_descendants(TreeSequence *self, PyObject *args, PyObject *kwds } static PyObject * -TreeSequence_extend_edges(TreeSequence *self, PyObject *args, PyObject *kwds) +TreeSequence_extend_haplotypes(TreeSequence *self, PyObject *args, PyObject *kwds) { int err; PyObject *ret = NULL; @@ -8970,7 +8970,7 @@ TreeSequence_extend_edges(TreeSequence *self, PyObject *args, PyObject *kwds) goto out; } - err = tsk_treeseq_extend_edges( + err = tsk_treeseq_extend_haplotypes( self->tree_sequence, max_iter, options, output->tree_sequence); if (err != 0) { handle_library_error(err); @@ -11200,10 +11200,10 @@ static PyMethodDef TreeSequence_methods[] = { .ml_meth = (PyCFunction) TreeSequence_split_edges, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Returns a copy of this tree sequence edges split at time t" }, - { .ml_name = "extend_edges", - .ml_meth = (PyCFunction) TreeSequence_extend_edges, + { .ml_name = "extend_haplotypes", + .ml_meth = (PyCFunction) TreeSequence_extend_haplotypes, .ml_flags = METH_VARARGS | METH_KEYWORDS, - .ml_doc = "Extends edges, creating unary nodes." }, + .ml_doc = "Extends ancestral haplotypes, creating unary nodes." }, { .ml_name = "has_reference_sequence", .ml_meth = (PyCFunction) TreeSequence_has_reference_sequence, .ml_flags = METH_NOARGS, diff --git a/python/tests/test_extend_edges.py b/python/tests/test_extend_edges.py deleted file mode 100644 index 7ee8c8f471..0000000000 --- a/python/tests/test_extend_edges.py +++ /dev/null @@ -1,654 +0,0 @@ -import msprime -import numpy as np -import pytest - -import _tskit -import tests.test_wright_fisher as wf -import tskit -from tests import tsutil -from tests.test_highlevel import get_example_tree_sequences - -# ↑ See https://github.com/tskit-dev/tskit/issues/1804 for when -# we can remove this. - - -def extend_edges(ts, max_iter=10): - tables = ts.dump_tables() - mutations = tables.mutations.copy() - tables.mutations.clear() - - last_num_edges = ts.num_edges - for _ in range(max_iter): - for forwards in [True, False]: - edges = _extend(ts, forwards=forwards) - tables.edges.replace_with(edges) - tables.build_index() - ts = tables.tree_sequence() - if ts.num_edges == last_num_edges: - break - else: - last_num_edges = ts.num_edges - - tables = ts.dump_tables() - mutations = _slide_mutation_nodes_up(ts, mutations) - tables.mutations.replace_with(mutations) - ts = tables.tree_sequence() - - return ts - - -def _slide_mutation_nodes_up(ts, mutations): - # adjusts mutations' nodes to place each mutation on the correct edge given - # their time; requires mutation times be nonmissing and the mutation times - # be >= their nodes' times. - - assert np.all(~tskit.is_unknown_time(mutations.time)), "times must be known" - new_nodes = mutations.node.copy() - - mut = 0 - for tree in ts.trees(): - _, right = tree.interval - while ( - mut < mutations.num_rows and ts.sites_position[mutations.site[mut]] < right - ): - t = mutations.time[mut] - c = mutations.node[mut] - p = tree.parent(c) - assert ts.nodes_time[c] <= t - while p != -1 and ts.nodes_time[p] <= t: - c = p - p = tree.parent(c) - assert ts.nodes_time[c] <= t - if p != -1: - assert t < ts.nodes_time[p] - new_nodes[mut] = c - mut += 1 - - # in C the node column can be edited in place - new_mutations = mutations.copy() - new_mutations.clear() - for mut, n in zip(mutations, new_nodes): - new_mutations.append(mut.replace(node=n)) - - return new_mutations - - -def _extend(ts, forwards=True): - # `degree` will record the degree of each node in the tree we'd get if - # we removed all `out` edges and added all `in` edges - degree = np.full(ts.num_nodes, 0, dtype="int") - # `out_parent` will record the sub-forest of edges-to-be-removed - out_parent = np.full(ts.num_nodes, -1, dtype="int") - keep = np.full(ts.num_edges, True, dtype=bool) - not_sample = [not n.is_sample() for n in ts.nodes()] - - edges = ts.tables.edges.copy() - - # "here" will be left if fowards else right; - # and "there" is the other - new_left = edges.left.copy() - new_right = edges.right.copy() - if forwards: - direction = 1 - # in C we can just modify these in place, but in - # python they are (silently) immutable - near_side = new_left - far_side = new_right - else: - direction = -1 - near_side = new_right - far_side = new_left - edges_out = [] - edges_in = [] - - tree_pos = tsutil.TreePosition(ts) - if forwards: - valid = tree_pos.next() - else: - valid = tree_pos.prev() - while valid: - left, right = tree_pos.interval - there = right if forwards else left - - # Clear out non-extended or postponed edges: - # Note: maintaining out_parent is a bit tricky, because - # if an edge from p->c has been extended, entirely replacing - # another edge from p'->c, then both edges may be in edges_out, - # and we only want to include the *first* one. - for e, _ in edges_out: - out_parent[edges.child[e]] = -1 - tmp = [] - for e, x in edges_out: - if x: - tmp.append([e, False]) - edges_out = tmp - tmp = [] - for e, x in edges_in: - if x: - tmp.append([e, False]) - edges_in = tmp - - for e, _ in edges_out: - out_parent[edges.child[e]] = edges.parent[e] - - for j in range(tree_pos.out_range.start, tree_pos.out_range.stop, direction): - e = tree_pos.out_range.order[j] - if out_parent[edges.child[e]] == -1: - edges_out.append([e, False]) - out_parent[edges.child[e]] = edges.parent[e] - - for j in range(tree_pos.in_range.start, tree_pos.in_range.stop, direction): - e = tree_pos.in_range.order[j] - edges_in.append([e, False]) - - for e, _ in edges_out: - degree[edges.parent[e]] -= 1 - degree[edges.child[e]] -= 1 - for e, _ in edges_in: - degree[edges.parent[e]] += 1 - degree[edges.child[e]] += 1 - - # validate out_parent array - for c, p in enumerate(out_parent): - foundit = False - for e, _ in edges_out: - if edges.child[e] == c: - assert edges.parent[e] == p - foundit = True - break - assert foundit == (p != -1) - - assert np.all(degree >= 0) - for ex_in in edges_in: - e_in = ex_in[0] - # check whether the parent-child relationship exists - # in the sub-forest of edges to be removed: - # out_parent[p] != -1 only when it is the bottom of - # an edge to be removed, - # and degree[p] == 0 only if it is not in the new tree - c = edges.child[e_in] - p = out_parent[c] - p_in = edges.parent[e_in] - while p != tskit.NULL and degree[p] == 0 and p != p_in and not_sample[p]: - p = out_parent[p] - if p == p_in: - # we might have passed the interval that a - # postponed edge in covers, in which case - # we should skip it - if far_side[e_in] != there: - ex_in[1] = True - near_side[e_in] = there - while c != p: - # just loop over the edges out until we find the right entry - for ex_out in edges_out: - e_out = ex_out[0] - if edges.child[e_out] == c: - break - assert edges.child[e_out] == c - ex_out[1] = True - far_side[e_out] = there - # amend degree: the intermediate - # nodes have 2 edges instead of 0 - assert degree[c] == 0 or c == edges.child[e_in] - if degree[c] == 0: - degree[c] = 2 - c = out_parent[c] - - # end of loop, next tree - if forwards: - valid = tree_pos.next() - else: - valid = tree_pos.prev() - - for j in range(edges.num_rows): - left = new_left[j] - right = new_right[j] - if left < right: - edges[j] = edges[j].replace(left=left, right=right) - else: - keep[j] = False - edges.keep_rows(keep) - return edges - - -class TestExtendEdges: - """ - Test the 'extend edges' method - """ - - def verify_extend_edges(self, ts, max_iter=10): - # This can still fail for various weird examples: - # for instance, if adjacent trees have - # a <- b <- c <- d and a <- d (where say b was - # inserted in an earlier pass), then b and c - # won't be extended - - ets = ts.extend_edges(max_iter=max_iter) - assert np.all(ts.genotype_matrix() == ets.genotype_matrix()) - assert ts.num_samples == ets.num_samples - assert ts.num_nodes == ets.num_nodes - assert ts.num_edges >= ets.num_edges - t = ts.simplify().tables - et = ets.simplify().tables - t.assert_equals(et, ignore_provenance=True) - old_edges = {} - for e in ts.edges(): - k = (e.parent, e.child) - if k not in old_edges: - old_edges[k] = [] - old_edges[k].append((e.left, e.right)) - - for e in ets.edges(): - # e should be in old_edges, - # but with modified limits: - # USUALLY overlapping limits, but - # not necessarily after more than one pass - k = (e.parent, e.child) - assert k in old_edges - if max_iter == 1: - overlaps = False - for left, right in old_edges[k]: - if (left <= e.right) and (right >= e.left): - overlaps = True - assert overlaps - - if max_iter > 1: - chains = [] - for _, tt, ett in ts.coiterate(ets): - this_chains = [] - for a in tt.nodes(): - assert a in ett.nodes() - b = tt.parent(a) - if b != tskit.NULL: - c = tt.parent(b) - if c != tskit.NULL: - this_chains.append((a, b, c)) - assert b in ett.nodes() - # the relationship a <- b should still be in the tree - p = a - while p != tskit.NULL and p != b: - p = ett.parent(p) - assert p == b - chains.append(this_chains) - - extended_ac = {} - not_extended_ac = {} - extended_ab = {} - not_extended_ab = {} - for k, (interval, tt, ett) in enumerate(ts.coiterate(ets)): - for j in (k - 1, k + 1): - if j < 0 or j >= len(chains): - continue - else: - this_chains = chains[j] - for a, b, c in this_chains: - if ( - a in tt.nodes() - and tt.parent(a) == c - and b not in tt.nodes() - ): - # the relationship a <- b <- c should still be in the tree, - # although maybe they aren't direct parent-offspring - # UNLESS we've got an ambiguous case, where on the opposite - # side of the interval a chain a <- b' <- c got extended - # into the region OR b got inserted into another chain - assert a in ett.nodes() - assert c in ett.nodes() - if b not in ett.nodes(): - if (a, c) not in not_extended_ac: - not_extended_ac[(a, c)] = [] - not_extended_ac[(a, c)].append(interval) - else: - if (a, c) not in extended_ac: - extended_ac[(a, c)] = [] - extended_ac[(a, c)].append(interval) - p = a - while p != tskit.NULL and p != b: - p = ett.parent(p) - if p != b: - if (a, b) not in not_extended_ab: - not_extended_ab[(a, b)] = [] - not_extended_ab[(a, b)].append(interval) - else: - if (a, b) not in extended_ab: - extended_ab[(a, b)] = [] - extended_ab[(a, b)].append(interval) - while p != tskit.NULL and p != c: - p = ett.parent(p) - assert p == c - for a, c in not_extended_ac: - # check that a <- ... <- c has been extended somewhere - # although not necessarily from an adjacent segment - assert (a, c) in extended_ac - for interval in not_extended_ac[(a, c)]: - ett = ets.at(interval.left) - assert ett.parent(a) != c - for k in not_extended_ab: - assert k in extended_ab - for interval in not_extended_ab[k]: - assert interval in extended_ab[k] - - # finally, compare C version to python version - py_ts = extend_edges(ts, max_iter=max_iter) - py_et = py_ts.dump_tables() - et = ets.dump_tables() - et.assert_equals(py_et) - - def test_runs(self): - ts = msprime.simulate(5, mutation_rate=1.0, random_seed=126) - self.verify_extend_edges(ts) - - def test_migrations_disallowed(self): - ts = msprime.simulate(5, mutation_rate=1.0, random_seed=126) - tables = ts.dump_tables() - tables.populations.add_row() - tables.populations.add_row() - tables.migrations.add_row(0, 1, 0, 0, 1, 0) - ts = tables.tree_sequence() - with pytest.raises( - _tskit.LibraryError, match="TSK_ERR_MIGRATIONS_NOT_SUPPORTED" - ): - _ = ts.extend_edges() - - def test_unknown_times(self): - ts = msprime.simulate(5, mutation_rate=1.0, random_seed=126) - tables = ts.dump_tables() - tables.mutations.clear() - for mut in ts.mutations(): - tables.mutations.append(mut.replace(time=tskit.UNKNOWN_TIME)) - ts = tables.tree_sequence() - with pytest.raises( - _tskit.LibraryError, match="TSK_ERR_DISALLOWED_UNKNOWN_MUTATION_TIME" - ): - _ = ts.extend_edges() - - def test_max_iter(self): - ts = msprime.simulate(5, random_seed=126) - with pytest.raises(_tskit.LibraryError, match="positive"): - ets = ts.extend_edges(max_iter=0) - with pytest.raises(_tskit.LibraryError, match="positive"): - ets = ts.extend_edges(max_iter=-1) - ets = ts.extend_edges(max_iter=1) - et = ets.extend_edges(max_iter=1).dump_tables() - eet = ets.extend_edges(max_iter=2).dump_tables() - eet.assert_equals(et) - - def get_simple_ex(self, samples=None): - # An example where you need to go forwards *and* backwards: - # 7 and 8 should be extended to the whole sequence - # - # 6 6 6 6 - # +-+-+ +-+-+ +-+-+ +-+-+ - # | | 7 | | 8 | | - # | | ++-+ | | +-++ | | - # 4 5 4 | | 4 | 5 4 5 - # +++ +++ +++ | | | | +++ +++ +++ - # 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 - # - # Result: - # - # 6 6 6 6 - # +-+-+ +-+-+ +-+-+ +-+-+ - # 7 8 7 8 7 8 7 8 - # | | ++-+ | | +-++ | | - # 4 5 4 | 5 4 | 5 4 5 - # +++ +++ +++ | | | | +++ +++ +++ - # 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 - - node_times = { - 0: 0, - 1: 0, - 2: 0, - 3: 0, - 4: 1.0, - 5: 1.0, - 6: 3.0, - 7: 2.0, - 8: 2.0, - } - # (p, c, l, r) - edges = [ - (4, 0, 0, 10), - (4, 1, 0, 5), - (4, 1, 7, 10), - (5, 2, 0, 2), - (5, 2, 5, 10), - (5, 3, 0, 2), - (5, 3, 5, 10), - (7, 2, 2, 5), - (7, 4, 2, 5), - (8, 1, 5, 7), - (8, 5, 5, 7), - (6, 3, 2, 5), - (6, 4, 0, 2), - (6, 4, 5, 10), - (6, 5, 0, 2), - (6, 5, 7, 10), - (6, 7, 2, 5), - (6, 8, 5, 7), - ] - # here is the 'right answer' (but note only with the default args) - extended_edges = [ - (4, 0, 0, 10), - (4, 1, 0, 5), - (4, 1, 7, 10), - (5, 2, 0, 2), - (5, 2, 5, 10), - (5, 3, 0, 10), - (7, 2, 2, 5), - (7, 4, 0, 10), - (8, 1, 5, 7), - (8, 5, 0, 10), - (6, 7, 0, 10), - (6, 8, 0, 10), - ] - tables = tskit.TableCollection(sequence_length=10) - if samples is None: - samples = [0, 1, 2, 3] - for n, t in node_times.items(): - flags = tskit.NODE_IS_SAMPLE if n in samples else 0 - tables.nodes.add_row(time=t, flags=flags) - for p, c, l, r in edges: - tables.edges.add_row(parent=p, child=c, left=l, right=r) - ts = tables.tree_sequence() - tables.edges.clear() - for p, c, l, r in extended_edges: - tables.edges.add_row(parent=p, child=c, left=l, right=r) - ets = tables.tree_sequence() - assert ts.num_edges == 18 - assert ets.num_edges == 12 - return ts, ets - - def test_simple_ex(self): - ts, right_ets = self.get_simple_ex() - ets = ts.extend_edges() - ets.tables.assert_equals(right_ets.tables) - self.verify_extend_edges(ts) - - def test_internal_samples(self): - # Now we should have the same but not extend 5 (where * is): - # - # 6 6 6 6 - # +-+-+ +-+-+ +-+-+ +-+-+ - # 7 * 7 * 7 8 7 8 - # | | ++-+ | | +-++ | | - # 4 5 4 | * 4 | 5 4 5 - # +++ +++ +++ | | | | +++ +++ +++ - # 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 - # - # (p, c, l, r) - edges = [ - (4, 0, 0, 10), - (4, 1, 0, 5), - (4, 1, 7, 10), - (5, 2, 0, 2), - (5, 2, 5, 10), - (5, 3, 0, 2), - (5, 3, 5, 10), - (7, 2, 2, 5), - (7, 4, 0, 10), - (8, 1, 5, 7), - (8, 5, 5, 10), - (6, 3, 2, 5), - (6, 5, 0, 2), - (6, 7, 0, 10), - (6, 8, 5, 10), - ] - ts, _ = self.get_simple_ex(samples=[0, 1, 2, 3, 5]) - tables = ts.dump_tables() - tables.edges.clear() - for p, c, l, r in edges: - tables.edges.add_row(parent=p, child=c, left=l, right=r) - ets = ts.extend_edges() - ets.tables.assert_equals(tables) - # validation doesn't work with internal, incomplete samples - # (and it would be a pain to make it work) - # self.verify_extend_edges(ts) - - def test_iterative_example(self): - # Here is the full tree; extend edges should be able to - # recover all unary nodes after simplification: - # - # 9 9 9 9 - # +-+-+ +--+--+ +---+---+ +-+-+--+ - # 8 | 8 | 8 | | 8 | | | - # | | +-+-+ | | | | | | | | - # 7 | | 7 | | 7 | | | | 7 - # +-+-+ | | +-++ | | +-++ | | | | | - # 6 | | | | 6 | | | 6 | | | | 6 - # +-++ | | | | | | | | | | | | | | - # 1 0 2 3 1 2 0 3 1 2 0 3 1 2 3 0 - # +++ +++ +++ +++ - # 4 5 4 5 4 5 4 5 - # - samples = [0, 1, 2, 3, 4, 5] - node_times = [1, 1, 1, 1, 0, 0, 2, 3, 4, 5] - # (p, c, l, r) - edges = [ - (0, 4, 0, 10), - (0, 5, 0, 10), - (6, 0, 0, 10), - (6, 1, 0, 3), - (7, 2, 0, 7), - (7, 6, 0, 10), - (8, 1, 3, 10), - (8, 7, 0, 5), - (9, 2, 7, 10), - (9, 3, 0, 10), - (9, 7, 5, 10), - (9, 8, 0, 10), - ] - tables = tskit.TableCollection(sequence_length=10) - for n, t in enumerate(node_times): - flags = tskit.NODE_IS_SAMPLE if n in samples else 0 - tables.nodes.add_row(time=t, flags=flags) - for p, c, l, r in edges: - tables.edges.add_row(parent=p, child=c, left=l, right=r) - ts = tables.tree_sequence() - sts = ts.simplify() - assert ts.num_edges == 12 - assert sts.num_edges == 16 - tables.assert_equals(sts.extend_edges().tables, ignore_provenance=True) - - def test_very_simple(self): - samples = [0] - node_times = [0, 1, 2, 3] - # (p, c, l, r) - edges = [ - (1, 0, 0, 1), - (2, 0, 1, 2), - (2, 1, 0, 1), - (3, 0, 2, 3), - (3, 2, 0, 2), - ] - correct_edges = [ - (1, 0, 0, 3), - (2, 1, 0, 3), - (3, 2, 0, 3), - ] - tables = tskit.TableCollection(sequence_length=3) - for n, t in enumerate(node_times): - flags = tskit.NODE_IS_SAMPLE if n in samples else 0 - tables.nodes.add_row(time=t, flags=flags) - for p, c, l, r in edges: - tables.edges.add_row(parent=p, child=c, left=l, right=r) - ts = tables.tree_sequence() - ets = ts.extend_edges() - for _, t, et in ts.coiterate(ets): - print("----") - print(t.draw(format="ascii")) - print(et.draw(format="ascii")) - etables = ets.tables - correct_tables = etables.copy() - etables.edges.clear() - for p, c, l, r in correct_edges: - etables.edges.add_row(parent=p, child=c, left=l, right=r) - etables.assert_equals(correct_tables, ignore_provenance=True) - - def test_wright_fisher(self): - tables = wf.wf_sim(N=5, ngens=20, num_loci=100, deep_history=False, seed=3) - tables.sort() - tables.simplify() - ts = msprime.sim_mutations(tables.tree_sequence(), rate=0.01, random_seed=888) - self.verify_extend_edges(ts, max_iter=1) - self.verify_extend_edges(ts) - - def test_wright_fisher_unsimplified(self): - tables = wf.wf_sim(N=6, ngens=22, num_loci=100, deep_history=False, seed=4) - tables.sort() - ts = msprime.sim_mutations(tables.tree_sequence(), rate=0.01, random_seed=888) - self.verify_extend_edges(ts, max_iter=1) - self.verify_extend_edges(ts) - - def test_wright_fisher_with_history(self): - tables = wf.wf_sim(N=8, ngens=15, num_loci=100, deep_history=True, seed=5) - tables.sort() - tables.simplify() - ts = msprime.sim_mutations(tables.tree_sequence(), rate=0.01, random_seed=888) - self.verify_extend_edges(ts, max_iter=1) - self.verify_extend_edges(ts) - - # This one fails sometimes but just because our verification can't handle - # figuring out what exactly should be the right answer in complex cases. - # - # def test_bigger_wright_fisher(self): - # tables = wf.wf_sim(N=50, ngens=15, deep_history=True, seed=6) - # tables.sort() - # tables.simplify() - # ts = tables.tree_sequence() - # self.verify_extend_edges(ts, max_iter=1) - # self.verify_extend_edges(ts, max_iter=200) - - -class TestExamples: - """ - Compare the ts method with local implementation. - """ - - def check(self, ts): - if np.any(tskit.is_unknown_time(ts.mutations_time)): - tables = ts.dump_tables() - tables.compute_mutation_times() - ts = tables.tree_sequence() - py_ts = extend_edges(ts) - lib_ts = ts.extend_edges() - lib_ts.tables.assert_equals(py_ts.tables) - assert np.all(ts.genotype_matrix() == lib_ts.genotype_matrix()) - sts = ts.simplify() - lib_sts = lib_ts.simplify() - lib_sts.tables.assert_equals(sts.tables, ignore_provenance=True) - - @pytest.mark.parametrize("ts", get_example_tree_sequences()) - def test_suite_examples_defaults(self, ts): - if ts.num_migrations == 0: - self.check(ts) - else: - with pytest.raises( - _tskit.LibraryError, match="TSK_ERR_MIGRATIONS_NOT_SUPPORTED" - ): - _ = ts.extend_edges() - - @pytest.mark.parametrize("n", [3, 4, 5]) - def test_all_trees_ts(self, n): - ts = tsutil.all_trees_ts(n) - self.check(ts) diff --git a/python/tests/test_extend_haplotypes.py b/python/tests/test_extend_haplotypes.py new file mode 100644 index 0000000000..0d4e40a7f8 --- /dev/null +++ b/python/tests/test_extend_haplotypes.py @@ -0,0 +1,1280 @@ +import msprime +import numpy as np +import pytest + +import _tskit +import tests.test_wright_fisher as wf +import tskit +from tests import tsutil +from tests.test_highlevel import get_example_tree_sequences + +# ↑ See https://github.com/tskit-dev/tskit/issues/1804 for when +# we can remove this. + + +def _slide_mutation_nodes_up(ts, mutations): + # adjusts mutations' nodes to place each mutation on the correct edge given + # their time; requires mutation times be nonmissing and the mutation times + # be >= their nodes' times. + + assert np.all(~tskit.is_unknown_time(mutations.time)), "times must be known" + new_nodes = mutations.node.copy() + + mut = 0 + for tree in ts.trees(): + _, right = tree.interval + while ( + mut < mutations.num_rows and ts.sites_position[mutations.site[mut]] < right + ): + t = mutations.time[mut] + c = mutations.node[mut] + p = tree.parent(c) + assert ts.nodes_time[c] <= t + while p != -1 and ts.nodes_time[p] <= t: + c = p + p = tree.parent(c) + assert ts.nodes_time[c] <= t + if p != -1: + assert t < ts.nodes_time[p] + new_nodes[mut] = c + mut += 1 + + # in C the node column can be edited in place + new_mutations = mutations.copy() + new_mutations.clear() + for mut, n in zip(mutations, new_nodes): + new_mutations.append(mut.replace(node=n)) + + return new_mutations + + +def print_edge_list(head, edges, left, right): + print("Edge list:") + for j, (e, x) in enumerate(head): + print( + f" {j}: {e} ({x}); " + + ( + f"{edges.child[e]}->{edges.parent[e]} on [{left[e]}, {right[e]})" + if e >= 0 + else "(null)" + ) + ) + print(f"length = {len(head)}") + + +class HaplotypeExtender: + def __init__(self, ts, forwards): + """ + Below we will iterate through the trees, either to the left or the right, + keeping the following state consistent: + - we are moving from a previous tree, last_tree, to new one, next_tree + - here: the position that separates the last_tree from the next_tree + - (here, there): the segment covered by next_tree + - edges_out: edges to be removed from last_tree to get next_tree + - parent_out: the forest induced by edges_out, a subset of last_tree + - edges_in: edges to be added to last_tree to get next_tree + - parent_in: the forest induced by edges_in, a subset of next_tree + - next_degree: the degree of each node in next_tree + - next_nodes_edge: for each node, the edge above it in next_tree + - last_degree: the degree of each node in last_tree + - last_nodes_edge: for each node, the edge above it in last_tree + Except: each of edges_in and edges_out is of the form e, x, and the + label x>0 if the edge is postponed to the next segment. + The label is x=1 for postponed edges, and x=2 for new edges. + In other words: + - elements e, x of edges_out with x=0 are in last_tree but not next_tree + - elements e, x of edges_in with x=0 are in next_tree but not last_tree + - elements e, x of edges_out with x=1 are in both trees, + and hence don't count for parent_out + - elements e, x of edges_in with x=1 are in neither, + and hence don't count for parent_in + - elements e, x for edges_out with x=2 have just been added, and so ought + to count towards the next tree, but we have to put them in edges out + because they'll be removed next time. + Notes: + - things having to do with last_tree do not change, + but things having to do with next_tree might change as we go along + - parent_out and parent_in do not refer to the *entire* last/next_tree, + but rather to *only* the edges_in/edges_out + Edges in can have one of three things happen to them: + 1. they get added to the next tree, as usual; + 2. they get postponed to the tree after the next tree, + and are thus part of edges_in again next time; + 3. they get postponed but run out of span so they dissapear entirely. + Edges out are similarly of four varieties: + 0. they are also in case (3) of edges_in, i.e., their extent was modified + when they were in edges_in so that they now have left=right; + 1. they get removed from the last tree, as usual; + 2. they get extended to the next tree, + and are thus part of edges_out again next time; + 3. they are in fact a newly added edge, and so are part of edges_out next time. + """ + self.ts = ts + self.edges = ts.tables.edges.copy() + self.new_left = ts.edges_left.copy() + self.new_right = ts.edges_right.copy() + self.last_degree = np.full(ts.num_nodes, 0, dtype="int") + self.next_degree = np.full(ts.num_nodes, 0, dtype="int") + self.parent_out = np.full(ts.num_nodes, -1, dtype="int") + self.parent_in = np.full(ts.num_nodes, -1, dtype="int") + self.not_sample = [not n.is_sample() for n in ts.nodes()] + self.next_nodes_edge = np.full(ts.num_nodes, -1, dtype="int") + self.last_nodes_edge = np.full(ts.num_nodes, -1, dtype="int") + + if forwards: + self.direction = 1 + # in C we can just modify these in place, but in + # python they are (silently) immutable + self.near_side = list(self.new_left) + self.far_side = list(self.new_right) + else: + self.direction = -1 + self.near_side = list(self.new_right) + self.far_side = list(self.new_left) + + self.edges_out = [] + self.edges_in = [] + + def print_state(self): + print("~~~~~~~~~~~~~~~~~~~~~~~~") + print("edges in:", self.edges_in) + print("parent out:") + for j, pj in enumerate(self.parent_out): + print(f" {j}: {pj}") + print("parent in:") + for j, pj in enumerate(self.parent_in): + print(f" {j}: {pj}") + print("edges out:", self.edges_out) + print("parent out:", self.parent_out) + print("last nodes edge:") + for j, ej in enumerate(self.last_nodes_edge): + print( + f" {j}: {ej}, " + + ( + "(null)" + if ej == -1 + else ( + f"({self.edges.child[ej]}->{self.edges.parent[ej]}, " + "{self.near_side[ej]}-{self.far_side[ej]}" + ) + ) + ) + for e, _ in self.edges_out: + print( + "edge out: ", + "e =", + e, + "c =", + self.edges.child[e], + "p =", + self.edges.parent[e], + self.near_side[e], + self.far_side[e], + ) + + def next_tree(self, tree_pos): + # Clear out non-extended or postponed edges: + # Note: maintaining parent_out is a bit tricky, because + # if an edge from p->c has been extended, entirely replacing + # another edge from p'->c, then both edges may be in edges_out, + # and we only want to include the *first* one. + + for e, x in self.edges_out: + self.parent_out[self.edges.child[e]] = tskit.NULL + if x > 1: + # this is needed to catch newly-created edges + self.last_nodes_edge[self.edges.child[e]] = e + self.last_degree[self.edges.child[e]] += 1 + self.last_degree[self.edges.parent[e]] += 1 + elif x == 0 and self.near_side[e] != self.far_side[e]: + self.last_nodes_edge[self.edges.child[e]] = tskit.NULL + self.last_degree[self.edges.child[e]] -= 1 + self.last_degree[self.edges.parent[e]] -= 1 + tmp = [] + for e, x in self.edges_out: + if x > 0: + tmp.append([e, 0]) + self.edges_out = tmp + for e, x in self.edges_in: + self.parent_in[self.edges.child[e]] = tskit.NULL + if x == 0 and self.near_side[e] != self.far_side[e]: + assert self.last_nodes_edge[self.edges.child[e]] == tskit.NULL + self.last_nodes_edge[self.edges.child[e]] = e + self.last_degree[self.edges.child[e]] += 1 + self.last_degree[self.edges.parent[e]] += 1 + tmp = [] + for e, x in self.edges_in: + if x > 0: + tmp.append([e, 0]) + self.edges_in = tmp + + # done cleanup from last tree transition; + # now we update the state to reflect the current tree transition + for j in range( + tree_pos.out_range.start, tree_pos.out_range.stop, self.direction + ): + e = tree_pos.out_range.order[j] + if (self.parent_out[self.edges.child[e]] == tskit.NULL) and ( + self.near_side[e] != self.far_side[e] + ): + self.edges_out.append([e, 0]) + + for e, _ in self.edges_out: + self.parent_out[self.edges.child[e]] = self.edges.parent[e] + self.next_nodes_edge[self.edges.child[e]] = tskit.NULL + self.next_degree[self.edges.child[e]] -= 1 + self.next_degree[self.edges.parent[e]] -= 1 + + for j in range(tree_pos.in_range.start, tree_pos.in_range.stop, self.direction): + e = tree_pos.in_range.order[j] + self.edges_in.append([e, 0]) + + for e, _ in self.edges_in: + self.parent_in[self.edges.child[e]] = self.edges.parent[e] + assert self.next_nodes_edge[self.edges.child[e]] == tskit.NULL + self.next_nodes_edge[self.edges.child[e]] = e + self.next_degree[self.edges.child[e]] += 1 + self.next_degree[self.edges.parent[e]] += 1 + + def check_state_at(self, pos, before, degree, nodes_edge): + # if before=True then we construct the state at epsilon-on-near-side-of `pos`, + # otherwise, at epsilon-on-far-side-of `pos`. + check_degree = np.zeros(self.ts.num_nodes, dtype="int") + check_nodes_edge = np.full(self.ts.num_nodes, -1, dtype="int") + assert len(self.near_side) == self.edges.num_rows + assert len(self.far_side) == self.edges.num_rows + for j, (e, l, r) in enumerate(zip(self.edges, self.near_side, self.far_side)): + overlaps = (l != r) and ( + ((pos - l) * (r - pos) > 0) + or (r == pos and before) + or (l == pos and not before) + ) + if overlaps: + check_degree[e.child] += 1 + check_degree[e.parent] += 1 + assert check_nodes_edge[e.child] == tskit.NULL + check_nodes_edge[e.child] = j + np.testing.assert_equal(check_nodes_edge, nodes_edge) + np.testing.assert_equal(check_degree, degree) + + def check_parent(self, parent, edge_ids): + temp_parent = np.full(self.ts.num_nodes, -1, dtype="int") + for j in edge_ids: + c = self.edges.child[j] + p = self.edges.parent[j] + temp_parent[c] = p + np.testing.assert_equal(temp_parent, parent) + + def check_state(self, here): + for e, x in self.edges_in: + assert x == 0 + assert self.near_side[e] != self.far_side[e] + for e, x in self.edges_out: + assert x == 0 + assert self.near_side[e] != self.far_side[e] + self.check_state_at(here, False, self.next_degree, self.next_nodes_edge) + self.check_state_at(here, True, self.last_degree, self.last_nodes_edge) + self.check_parent(self.parent_in, [j for j, x in self.edges_in if x == 0]) + self.check_parent(self.parent_out, [j for j, x in self.edges_out if x == 0]) + + def add_or_extend_edge(self, new_parent, child, left, right): + there = right if (self.direction == 1) else left + old_edge = self.next_nodes_edge[child] + if old_edge != tskit.NULL: + old_parent = self.edges.parent[old_edge] + else: + old_parent = tskit.NULL + if new_parent != old_parent: + # if our new edge is in edges_out, it should be extended + if self.parent_out[child] == new_parent: + e_out = self.last_nodes_edge[child] + assert e_out >= 0 + assert self.edges.child[e_out] == child + assert self.edges.parent[e_out] == new_parent + self.far_side[e_out] = there + assert self.near_side[e_out] != self.far_side[e_out] + for ex_out in self.edges_out: + if ex_out[0] == e_out: + break + assert ex_out[0] == e_out + ex_out[1] = 1 + else: + e_out = self.add_edge(new_parent, child, left, right) + self.edges_out.append([e_out, 2]) + # If we're replacing the edge above this node, it must be in edges_in; + # note that this assertion excludes the case that we're interrupting + # an existing edge. + assert (self.next_nodes_edge[child] == tskit.NULL) or ( + self.next_nodes_edge[child] in [e for e, _ in self.edges_in] + ) + self.next_nodes_edge[child] = e_out + self.next_degree[child] += 1 + self.next_degree[new_parent] += 1 + self.parent_out[child] = tskit.NULL + if old_edge != tskit.NULL: + for ex_in in self.edges_in: + e_in = ex_in[0] + if e_in == old_edge and (ex_in[1] == 0): + self.near_side[e_in] = there + if self.far_side[e_in] != there: + ex_in[1] = 1 + self.next_nodes_edge[child] = tskit.NULL + self.next_degree[child] -= 1 + self.next_degree[self.parent_in[child]] -= 1 + self.parent_in[child] = tskit.NULL + + def add_edge(self, parent, child, left, right): + new_id = self.edges.add_row(parent=parent, child=child, left=left, right=right) + # this appending should not be necessary in C + if self.direction == 1: + self.near_side.append(left) + self.far_side.append(right) + else: + self.near_side.append(right) + self.far_side.append(left) + return new_id + + def mergeable(self, c): + # returns a finite number of new edges needed + # if the paths in parent_in and parent_out + # up through nodes that aren't in the other tree + # end at the same place and don't have conflicting times; + # otherwise, returns Inf + p_out = self.parent_out[c] + p_in = self.parent_in[c] + t_out = np.inf if p_out == tskit.NULL else self.ts.nodes_time[p_out] + t_in = np.inf if p_in == tskit.NULL else self.ts.nodes_time[p_in] + child = c + num_new_edges = 0 + num_extended = 0 + while True: + climb_in = ( + p_in != tskit.NULL + and self.last_degree[p_in] == 0 + and self.not_sample[p_in] + and t_in < t_out + ) + climb_out = ( + p_out != tskit.NULL + and self.next_degree[p_out] == 0 + and self.not_sample[p_out] + and t_out < t_in + ) + if climb_in: + if self.parent_in[child] != p_in and self.parent_out[child] != p_in: + num_new_edges += 1 + child = p_in + p_in = self.parent_in[p_in] + t_in = np.inf if p_in == tskit.NULL else self.ts.nodes_time[p_in] + elif climb_out: + if self.parent_in[child] != p_out and self.parent_out[child] != p_out: + num_new_edges += 1 + child = p_out + p_out = self.parent_out[p_out] + t_out = np.inf if p_out == tskit.NULL else self.ts.nodes_time[p_out] + num_extended += 1 + else: + break + if num_extended == 0 or p_in != p_out or p_in == tskit.NULL: + num_new_edges = np.inf + return num_new_edges + + def merge_paths(self, c, left, right): + p_out = self.parent_out[c] + p_in = self.parent_in[c] + t_out = self.ts.nodes_time[p_out] + t_in = self.ts.nodes_time[p_in] + child = c + while True: + climb_in = ( + p_in != tskit.NULL + and self.last_degree[p_in] == 0 + and self.not_sample[p_in] + and t_in < t_out + ) + climb_out = ( + p_out != tskit.NULL + and self.next_degree[p_out] == 0 + and self.not_sample[p_out] + and t_out < t_in + ) + if climb_in: + self.add_or_extend_edge(p_in, child, left, right) + child = p_in + p_in = self.parent_in[p_in] + t_in = np.inf if p_in == tskit.NULL else self.ts.nodes_time[p_in] + elif climb_out: + self.add_or_extend_edge(p_out, child, left, right) + child = p_out + p_out = self.parent_out[p_out] + t_out = np.inf if p_out == tskit.NULL else self.ts.nodes_time[p_out] + else: + break + assert p_out == p_in + self.add_or_extend_edge(p_out, child, left, right) + + def extend_haplotypes(self): + tree_pos = tsutil.TreePosition(self.ts) + if self.direction == 1: + valid = tree_pos.next() + else: + valid = tree_pos.prev() + while valid: + left, right = tree_pos.interval + # there = right if self.direction == 1 else left + here = left if self.direction == 1 else right + self.next_tree(tree_pos) + self.check_state(here) + max_new_edges = 0 + next_max_new_edges = np.inf + while max_new_edges < np.inf: + for e_in, x in self.edges_in: + if x == 0: + c = self.edges.child[e_in] + assert self.next_degree[c] > 0 + if self.last_degree[c] > 0: + ne = self.mergeable(c) + if ne <= max_new_edges: + self.merge_paths(c, left, right) + else: + next_max_new_edges = min(ne, next_max_new_edges) + max_new_edges = next_max_new_edges + next_max_new_edges = np.inf + # end of loop, next tree + if self.direction == 1: + valid = tree_pos.next() + else: + valid = tree_pos.prev() + if self.direction == 1: + self.new_left = np.array(self.near_side) + self.new_right = np.array(self.far_side) + else: + self.new_right = np.array(self.near_side) + self.new_left = np.array(self.far_side) + # Get rid of adjacent, identical edges + keep = np.full(self.edges.num_rows, True, dtype=bool) + for j in range(self.edges.num_rows - 1): + if ( + self.edges.parent[j] == self.edges.parent[j + 1] + and self.edges.child[j] == self.edges.child[j + 1] + and self.new_right[j] == self.new_left[j + 1] + ): + self.new_right[j] = self.new_right[j + 1] + self.new_left[j + 1] = self.new_right[j + 1] + for j in range(self.edges.num_rows): + left = self.new_left[j] + right = self.new_right[j] + if left < right: + self.edges[j] = self.edges[j].replace(left=left, right=right) + else: + keep[j] = False + self.edges.keep_rows(keep) + + +def extend_haplotypes(ts, max_iter=10): + tables = ts.dump_tables() + mutations = tables.mutations.copy() + tables.mutations.clear() + + last_num_edges = ts.num_edges + for _ in range(max_iter): + for forwards in [True, False]: + extender = HaplotypeExtender(ts, forwards=forwards) + extender.extend_haplotypes() + tables.edges.replace_with(extender.edges) + tables.sort() + tables.build_index() + ts = tables.tree_sequence() + if ts.num_edges == last_num_edges: + break + else: + last_num_edges = ts.num_edges + + tables = ts.dump_tables() + mutations = _slide_mutation_nodes_up(ts, mutations) + tables.mutations.replace_with(mutations) + tables.sort() + ts = tables.tree_sequence() + return ts + + +def _path_pairs(tree): + for c in tree.postorder(): + p = tree.parent(c) + while p != tskit.NULL: + yield (c, p) + p = tree.parent(p) + + +def _path_up(c, p, tree, include_parent=False): + # path from c up to p in tree, not including c or p + c = tree.parent(c) + while c != p and c != tskit.NULL: + yield c + c = tree.parent(c) + assert c == p + if include_parent: + yield p + + +def _path_up_pairs(c, p, tree, others): + # others should be a list of nodes + otherdict = {tree.time(n): n for n in others} + ot = min(otherdict) + for n in _path_up(c, p, tree, include_parent=True): + nt = tree.time(n) + while ot < nt: + on = otherdict.pop(ot) + yield c, on + c = on + if len(otherdict) > 0: + ot = min(otherdict) + else: + ot = np.inf + yield c, n + c = n + assert n == p + assert len(otherdict) == 0 + + +def _path_overlaps(c, p, tree1, tree2): + for n in _path_up(c, p, tree1): + if n in tree2.nodes(): + return True + return False + + +def _paths_mergeable(c, p, tree1, tree2): + # checks that the nodes between c and p in each tree + # are not present in the other tree + # and their sets of times are disjoint + nodes1 = set(tree1.nodes()) + nodes2 = set(tree2.nodes()) + assert c in nodes1, f"child node {c} not in tree1" + assert p in nodes1, f"parent node {p} not in tree1" + assert c in nodes2, f"child node {c} not in tree2" + assert p in nodes2, f"parent node {p} not in tree2" + path1 = set(_path_up(c, p, tree1)) + path2 = set(_path_up(c, p, tree2)) + times1 = {tree1.time(n) for n in path1} + times2 = {tree2.time(n) for n in path2} + return ( + (not _path_overlaps(c, p, tree1, tree2)) + and (not _path_overlaps(c, p, tree2, tree1)) + and len(times1.intersection(times2)) == 0 + ) + + +def _extend_nodes(ts, interval, extendable): + tables = ts.dump_tables() + tables.edges.clear() + mutations = tables.mutations.copy() + tables.mutations.clear() + left, right = interval + # print("=================") + # print("extending", left, right) + extend_above = {} # gives the new child->parent mapping + todo_edges = np.repeat(True, ts.num_edges) + tree = ts.at(left) + for c, p, others in extendable: + # print("c:", c, "p:", p, "others:", others) + others_not_done_yet = set(others) - set(extend_above) + if len(others_not_done_yet) > 0: + for cn, pn in _path_up_pairs(c, p, tree, others_not_done_yet): + if cn not in extend_above: + assert cn not in extend_above + extend_above[cn] = pn + for c, p in extend_above.items(): + e = tree.edge(c) + if e == tskit.NULL or ts.edge(e).parent != p: + # print("adding", c, p) + tables.edges.add_row(child=c, parent=p, left=left, right=right) + if e != tskit.NULL: + edge = ts.edge(e) + # adjust endpoints on existing edge + for el, er in [ + (max(edge.left, right), edge.right), + (edge.left, min(edge.right, left)), + ]: + if el < er: + # print("replacing", edge, el, er) + tables.edges.append(edge.replace(left=el, right=er)) + todo_edges[e] = False + for todo, edge in zip(todo_edges, ts.edges()): + if todo: + # print("retaining", edge) + tables.edges.append(edge) + tables.sort() + ts = tables.tree_sequence() + mutations = _slide_mutation_nodes_up(ts, mutations) + tables.mutations.replace_with(mutations) + tables.sort() + return tables.tree_sequence() + + +def _naive_pass(ts, direction): + assert direction in (-1, +1) + num_trees = ts.num_trees + if direction == +1: + indexes = range(0, num_trees - 1, 1) + else: + indexes = range(num_trees - 1, 0, -1) + for tj in indexes: + extendable = [] + this_tree = ts.at_index(tj) + next_tree = ts.at_index(tj + direction) + # print("-----------", this_tree.index) + # print(this_tree.draw_text()) + # print(next_tree.draw_text()) + for c, p in _path_pairs(this_tree): + if ( + p != this_tree.parent(c) + and p in next_tree.nodes() + and c in next_tree.nodes(p) + ): + # print(c, p, " and ", list(next_tree.nodes(p))) + if _paths_mergeable(c, p, this_tree, next_tree): + extendable.append((c, p, list(_path_up(c, p, this_tree)))) + # print("extending to", extendable) + ts = _extend_nodes(ts, next_tree.interval, extendable) + assert num_trees == ts.num_trees + return ts + + +def naive_extend_haplotypes(ts, max_iter=20): + for _ in range(max_iter): + ets = _naive_pass(ts, +1) + ets = _naive_pass(ets, -1) + if ets == ts: + break + ts = ets + return ts + + +class TestExtendThings: + """ + Common utilities in the two classes below. + """ + + def verify_simplify_equality(self, ts, ets): + assert ts.num_nodes == ets.num_nodes + assert ts.num_samples == ets.num_samples + t = ts.simplify().tables + et = ets.simplify().tables + et.assert_equals(t, ignore_provenance=True) + assert np.all(ts.genotype_matrix() == ets.genotype_matrix()) + + def naive_verify(self, ts): + ets = naive_extend_haplotypes(ts) + self.verify_simplify_equality(ts, ets) + + +class TestExtendHaplotypes(TestExtendThings): + """ + Test the 'extend_haplotypes' method. + """ + + def get_example1(self): + # 15.00| | 13 | | + # | | | | | + # 12.00| 10 | 10 | 10 | + # | +-+-+ | +-+-+ | +-+-+ | + # 10.00| 8 | | | | | 8 | | + # | | | | | | | ++-+ | | + # 8.00 | | | | 11 12 | | | | | + # | | | | | | | | | | | + # 6.00 | | | | 7 | | | | | | + # | | | | | | | | | | | + # 4.00 | 6 9 | | | | | | | | + # | | | | | | | | | | | + # 1.00 | 4 5 | 4 5 | 4 | 5 | + # | +++ +++ | +++ +++ | +++ | | | + # 0.00 | 0 1 2 3 | 0 1 2 3 | 0 1 2 3 | + # 0 3 6 9 + node_times = { + 0: 0, + 1: 0, + 2: 0, + 3: 0, + 4: 1, + 5: 1, + 6: 4, + 7: 6, + 8: 10, + 9: 4, + 10: 12, + 11: 8, + 12: 8, + 13: 15, + } + # (p,c,l,r) + edges = [ + (4, 0, 0, 9), + (4, 1, 0, 9), + (5, 2, 0, 6), + (5, 3, 0, 9), + (6, 4, 0, 3), + (9, 5, 0, 3), + (7, 4, 3, 6), + (11, 7, 3, 6), + (12, 5, 3, 6), + (8, 2, 6, 9), + (8, 4, 6, 9), + (8, 6, 0, 3), + (10, 5, 6, 9), + (10, 8, 0, 3), + (10, 8, 6, 9), + (10, 9, 0, 3), + (10, 11, 3, 6), + (10, 12, 3, 6), + (13, 10, 3, 6), + ] + extended_edges = [ + (4, 0, 0.0, 9.0), + (4, 1, 0.0, 9.0), + (5, 2, 0.0, 6.0), + (5, 3, 0.0, 9.0), + (6, 4, 0.0, 9.0), + (9, 5, 0.0, 9.0), + (7, 6, 0.0, 9.0), + (11, 7, 0.0, 9.0), + (12, 9, 0.0, 9.0), + (8, 2, 6.0, 9.0), + (8, 11, 0.0, 9.0), + (10, 8, 0.0, 9.0), + (10, 12, 0.0, 9.0), + (13, 10, 3.0, 6.0), + ] + samples = list(np.arange(4)) + tables = tskit.TableCollection(sequence_length=9) + for ( + n, + t, + ) in node_times.items(): + flags = tskit.NODE_IS_SAMPLE if n in samples else 0 + tables.nodes.add_row(time=t, flags=flags) + for p, c, l, r in edges: + tables.edges.add_row(parent=p, child=c, left=l, right=r) + ts = tables.tree_sequence() + tables.edges.clear() + for p, c, l, r in extended_edges: + tables.edges.add_row(parent=p, child=c, left=l, right=r) + ets = tables.tree_sequence() + assert ts.num_edges == 19 + assert ets.num_edges == 14 + return ts, ets + + def get_example2(self): + # 12.00| | 21 | | + # | | +----+-----+ | | + # 11.00| 20 | | | | 20 | + # | +----+---+ | | | | +----+---+ | + # 10.00| | 19 | | 19 | | 19 | + # | | ++-+ | | +-+-+ | | ++-+ | + # 9.00 | 18 | | | 18 | | | 18 | | | + # | +--+--+ | | | +--+--+ | | | +--+--+ | | | + # 8.00 | | | | | | | | | | | 17 | | | | + # | | | | | | | | | | | +-+-+ | | | | + # 7.00 | | | 16 | | | | 16 | | | | | | | | + # | | | +++ | | | | +-++ | | | | | | | | + # 6.00 | 15 | | | | | | | | | | | | | | | | | + # | +-+-+ | | | | | | | | | | | | | | | | | + # 5.00 | | | 14 | | | | | 14 | | | | | | 14 | | | + # | | | ++-+ | | | | | ++-+ | | | | | | ++-+ | | | + # 4.00 | 13 | | | | | | | 13 | | | | | | 13 | | | | | | + # | ++-+ | | | | | | | ++-+ | | | | | | ++-+ | | | | | | + # 3.00 | | | | | | | | | | | | | | | 12 | | | | | | | 12 | | + # | | | | | | | | | | | | | | | +++ | | | | | | | +++ | | + # 2.00 | 11 | | | | | | | | 11 | | | | | | | | 11 | | | | | | | | + # | +++ | | | | | | | | +++ | | | | | | | | +++ | | | | | | | | + # 1.00 | | | | | 10 | | | | | | | | 10 | | | | | | | | | | 10 | | | | | + # | | | | | +++ | | | | | | | | +++ | | | | | | | | | | +++ | | | | | + # 0.00 | 0 7 4 9 2 5 6 1 3 8 | 0 7 4 2 5 6 1 3 9 8 | 0 7 4 1 2 5 6 3 9 8 | + # 0 3 6 9 + node_times = { + 0: 0, + 1: 0, + 2: 0, + 3: 0, + 4: 0, + 5: 0, + 6: 0, + 7: 0, + 8: 0, + 9: 0, + 10: 1, + 11: 2, + 12: 3, + 13: 4, + 14: 5, + 15: 6, + 16: 7, + 17: 8, + 18: 9, + 19: 10, + 20: 11, + 21: 12, + } + # (p,c,l,r) + edges = [ + (10, 2, 0, 9), + (10, 5, 0, 9), + (11, 0, 0, 9), + (11, 7, 0, 9), + (12, 3, 3, 9), + (12, 9, 3, 9), + (13, 4, 0, 9), + (13, 11, 0, 9), + (14, 6, 0, 9), + (14, 10, 0, 9), + (15, 9, 0, 3), + (15, 13, 0, 3), + (16, 1, 0, 6), + (16, 3, 0, 3), + (16, 12, 3, 6), + (17, 1, 6, 9), + (17, 13, 6, 9), + (18, 13, 3, 6), + (18, 14, 0, 9), + (18, 15, 0, 3), + (18, 17, 6, 9), + (19, 8, 0, 9), + (19, 12, 6, 9), + (19, 16, 0, 6), + (20, 18, 0, 3), + (20, 18, 6, 9), + (20, 19, 0, 3), + (20, 19, 6, 9), + (21, 18, 3, 6), + (21, 19, 3, 6), + ] + extended_edges = [ + (10, 2, 0.0, 9.0), + (10, 5, 0.0, 9.0), + (11, 0, 0.0, 9.0), + (11, 7, 0.0, 9.0), + (12, 3, 0.0, 9.0), + (12, 9, 3.0, 9.0), + (13, 4, 0.0, 9.0), + (13, 11, 0.0, 9.0), + (14, 6, 0.0, 9.0), + (14, 10, 0.0, 9.0), + (15, 9, 0.0, 3.0), + (15, 13, 0.0, 9.0), + (16, 1, 0.0, 6.0), + (16, 12, 0.0, 9.0), + (17, 1, 6.0, 9.0), + (17, 15, 0.0, 9.0), + (18, 14, 0.0, 9.0), + (18, 17, 0.0, 9.0), + (19, 8, 0.0, 9.0), + (19, 16, 0.0, 9.0), + (20, 18, 0.0, 3.0), + (20, 18, 6.0, 9.0), + (20, 19, 0.0, 3.0), + (20, 19, 6.0, 9.0), + (21, 18, 3.0, 6.0), + (21, 19, 3.0, 6.0), + ] + samples = list(np.arange(10)) + tables = tskit.TableCollection(sequence_length=9) + for ( + n, + t, + ) in node_times.items(): + flags = tskit.NODE_IS_SAMPLE if n in samples else 0 + tables.nodes.add_row(time=t, flags=flags) + for p, c, l, r in edges: + tables.edges.add_row(parent=p, child=c, left=l, right=r) + ts = tables.tree_sequence() + tables.edges.clear() + for p, c, l, r in extended_edges: + tables.edges.add_row(parent=p, child=c, left=l, right=r) + ets = tables.tree_sequence() + assert ts.num_edges == 30 + assert ets.num_edges == 26 + return ts, ets + + def get_example3(self): + # Here is the full tree; extend edges should be able to + # recover all unary nodes after simplification: + # + # 9 9 9 9 + # +-+-+ +--+--+ +---+---+ +-+-+--+ + # 8 | 8 | 8 | | 8 | | | + # | | +-+-+ | | | | | | | | + # 7 | | 7 | | 7 | | | | 7 + # +-+-+ | | +-++ | | +-++ | | | | | + # 6 | | | | 6 | | | 6 | | | | 6 + # +-++ | | | | | | | | | | | | | | + # 1 0 2 3 1 2 0 3 1 2 0 3 1 2 3 0 + # +++ +++ +++ +++ + # 4 5 4 5 4 5 4 5 + # + samples = [0, 1, 2, 3, 4, 5] + node_times = [1, 1, 1, 1, 0, 0, 2, 3, 4, 5] + # (p, c, l, r) + edges = [ + (0, 4, 0, 10), + (0, 5, 0, 10), + (6, 0, 0, 10), + (6, 1, 0, 3), + (7, 2, 0, 7), + (7, 6, 0, 10), + (8, 1, 3, 10), + (8, 7, 0, 5), + (9, 2, 7, 10), + (9, 3, 0, 10), + (9, 7, 5, 10), + (9, 8, 0, 10), + ] + tables = tskit.TableCollection(sequence_length=10) + for n, t in enumerate(node_times): + flags = tskit.NODE_IS_SAMPLE if n in samples else 0 + tables.nodes.add_row(time=t, flags=flags) + for p, c, l, r in edges: + tables.edges.add_row(parent=p, child=c, left=l, right=r) + ets = tables.tree_sequence() + ts = ets.simplify() + assert ts.num_edges == 16 + assert ets.num_edges == 12 + return ts, ets + + def get_example4(self): + # 7 and 8 should be extended to the whole sequence; + # and also 5 to the second tree + # + # 6 6 6 6 + # +-+-+ +-+-+ +-+-+ +-+-+ + # | | 7 | | 8 | | + # | | ++-+ | | +-++ | | + # 4 5 4 | | 4 | 5 4 5 + # +++ +++ +++ | | | | +++ +++ +++ + # 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 + node_times = (0, 0, 0, 0, 1, 1, 3, 2, 2) + samples = (0, 1, 2, 3) + # (p, c, l, r) + extended_edges = [ + (4, 0, 0, 10), + (4, 1, 0, 5), + (4, 1, 7, 10), + (5, 2, 0, 2), + (5, 2, 5, 10), + (5, 3, 0, 10), + (7, 2, 2, 5), + (7, 4, 0, 10), + (8, 1, 5, 7), + (8, 5, 0, 10), + (6, 7, 0, 10), + (6, 8, 0, 10), + ] + edges = [ + (4, 0, 0, 10), + (4, 1, 0, 5), + (4, 1, 7, 10), + (5, 2, 0, 2), + (5, 2, 5, 10), + (5, 3, 0, 2), + (5, 3, 5, 10), + (7, 2, 2, 5), + (7, 4, 2, 5), + (8, 1, 5, 7), + (8, 5, 5, 7), + (6, 3, 2, 5), + (6, 4, 0, 2), + (6, 4, 5, 10), + (6, 5, 0, 2), + (6, 5, 7, 10), + (6, 7, 2, 5), + (6, 8, 5, 7), + ] + tables = tskit.TableCollection(sequence_length=10) + tables.sort() + for n, t in enumerate(node_times): + flags = tskit.NODE_IS_SAMPLE if n in samples else 0 + tables.nodes.add_row(time=t, flags=flags) + for p, c, l, r in edges: + tables.edges.add_row(parent=p, child=c, left=l, right=r) + ts = tables.tree_sequence() + tables.edges.clear() + for p, c, l, r in extended_edges: + tables.edges.add_row(parent=p, child=c, left=l, right=r) + ets = tables.tree_sequence() + assert ts.num_edges == 18 + assert ets.num_edges == 12 + return ts, ets + + def get_example5(self): + # This is an example where new edges are added + # on both forwards and back passes + # 4.00┊ ┊ 4 ┊ 4 ┊ + # ┊ ┊ ┃ ┊ ┃ ┊ + # 3.00┊ 2 ┊ ┃ ┊ 2 ┊ + # ┊ ┃ ┊ ┃ ┊ ┃ ┊ + # 2.00┊ ┃ ┊ 3 ┊ ┃ ┊ + # ┊ ┃ ┊ ┃ ┊ ┃ ┊ + # 1.00┊ 1 ┊ ┃ ┊ ┃ ┊ + # ┊ ┃ ┊ ┃ ┊ ┃ ┊ + # 0.00┊ 0 ┊ 0 ┊ 0 ┊ + # 0 2 4 6 + node_times = (0, 1, 3, 2, 4) + samples = (0,) + # (p, c, l, r) + edges = [ + (1, 0, 0, 2), + (2, 1, 0, 2), + (3, 0, 2, 4), + (4, 3, 2, 4), + (4, 2, 4, 6), + (2, 0, 4, 6), + ] + extended_edges = [ + (1, 0, 0, 6), + (3, 1, 0, 6), + (2, 3, 0, 6), + (4, 2, 2, 6), + ] + site_positions = (3,) + # site, node, derived_state, time + mutations = [ + (0, 4, 5, 4.5), + (0, 3, 4, 3.5), + (0, 3, 3, 2.5), + (0, 0, 2, 1.5), + (0, 0, 1, 0.5), + ] + extended_mutations_node = [4, 2, 3, 1, 0] + tables = tskit.TableCollection(sequence_length=6) + for n, t in enumerate(node_times): + flags = tskit.NODE_IS_SAMPLE if n in samples else 0 + tables.nodes.add_row(time=t, flags=flags) + for p, c, l, r in edges: + tables.edges.add_row(parent=p, child=c, left=l, right=r) + for x in site_positions: + tables.sites.add_row(ancestral_state="0", position=x) + for s, n, d, t in mutations: + tables.mutations.add_row(site=s, node=n, derived_state=str(d), time=t) + tables.sort() + tables.build_index() + tables.compute_mutation_parents() + ts = tables.tree_sequence() + tables.edges.clear() + for p, c, l, r in extended_edges: + tables.edges.add_row(parent=p, child=c, left=l, right=r) + tables.sort() + tables.mutations.clear() + for (s, _, d, t), n in zip(mutations, extended_mutations_node): + tables.mutations.add_row(site=s, node=n, derived_state=str(d), time=t) + tables.build_index() + tables.compute_mutation_parents() + ets = tables.tree_sequence() + return ts, ets + + def get_example(self, j): + if j == 1: + ts, ets = self.get_example1() + elif j == 2: + ts, ets = self.get_example2() + elif j == 3: + ts, ets = self.get_example3() + elif j == 4: + ts, ets = self.get_example4() + elif j == 5: + ts, ets = self.get_example5() + else: + raise ValueError + return ts, ets + + def verify_extend_haplotypes(self, ts, max_iter=10): + ets = ts.extend_haplotypes(max_iter=max_iter) + py_ets = extend_haplotypes(ts, max_iter=max_iter) + ets.tables.assert_equals(py_ets.tables, ignore_provenance=True) + self.verify_simplify_equality(ts, ets) + + def test_runs(self): + ts = msprime.simulate(5, mutation_rate=1.0, random_seed=126) + self.verify_extend_haplotypes(ts) + self.naive_verify(ts) + + @pytest.mark.parametrize("j", [1, 2, 3, 4, 5]) + def test_example(self, j): + ts, correct_ets = self.get_example(j) + test_ets = ts.extend_haplotypes() + test_ets.tables.assert_equals(correct_ets.tables, ignore_provenance=True) + self.verify_extend_haplotypes(ts) + self.naive_verify(ts) + + @pytest.mark.parametrize("j", [1, 2, 3, 4, 5]) + def test_redundant_breakpoitns(self, j): + ts, correct_ets = self.get_example(j) + ts = tsutil.insert_redundant_breakpoints(ts) + test_ets = ts.extend_haplotypes() + test_ets.tables.assert_equals(correct_ets.tables, ignore_provenance=True) + self.verify_extend_haplotypes(ts) + self.naive_verify(ts) + + def test_migrations_disallowed(self): + ts = msprime.simulate(5, mutation_rate=1.0, random_seed=126) + tables = ts.dump_tables() + tables.populations.add_row() + tables.populations.add_row() + tables.migrations.add_row(0, 1, 0, 0, 1, 0) + ts = tables.tree_sequence() + with pytest.raises( + _tskit.LibraryError, match="TSK_ERR_MIGRATIONS_NOT_SUPPORTED" + ): + _ = ts.extend_haplotypes() + + def test_unknown_times(self): + ts = msprime.simulate(5, mutation_rate=1.0, random_seed=126) + tables = ts.dump_tables() + tables.mutations.clear() + for mut in ts.mutations(): + tables.mutations.append(mut.replace(time=tskit.UNKNOWN_TIME)) + ts = tables.tree_sequence() + with pytest.raises( + _tskit.LibraryError, match="TSK_ERR_DISALLOWED_UNKNOWN_MUTATION_TIME" + ): + _ = ts.extend_haplotypes() + + def test_max_iter(self): + ts = msprime.simulate(5, random_seed=126) + with pytest.raises(_tskit.LibraryError, match="positive"): + ets = ts.extend_haplotypes(max_iter=0) + with pytest.raises(_tskit.LibraryError, match="positive"): + ets = ts.extend_haplotypes(max_iter=-1) + ets = ts.extend_haplotypes(max_iter=1) + et = ets.extend_haplotypes(max_iter=1).dump_tables() + eet = ets.extend_haplotypes(max_iter=2).dump_tables() + eet.assert_equals(et) + + def test_very_simple(self): + samples = [0] + node_times = [0, 1, 2, 3] + # (p, c, l, r) + edges = [ + (1, 0, 0, 1), + (2, 0, 1, 2), + (2, 1, 0, 1), + (3, 0, 2, 3), + (3, 2, 0, 2), + ] + correct_edges = [ + (1, 0, 0, 3), + (2, 1, 0, 3), + (3, 2, 0, 3), + ] + tables = tskit.TableCollection(sequence_length=3) + for n, t in enumerate(node_times): + flags = tskit.NODE_IS_SAMPLE if n in samples else 0 + tables.nodes.add_row(time=t, flags=flags) + for p, c, l, r in edges: + tables.edges.add_row(parent=p, child=c, left=l, right=r) + ts = tables.tree_sequence() + ets = extend_haplotypes(ts) + etables = ets.tables + correct_tables = etables.copy() + correct_tables.edges.clear() + for p, c, l, r in correct_edges: + correct_tables.edges.add_row(parent=p, child=c, left=l, right=r) + etables.assert_equals(correct_tables, ignore_provenance=True) + self.naive_verify(ts) + + def test_internal_samples(self): + # Now we should have the same but not extend 5 (where * is), + # since 5 is a sample; nor 8 because it's extension depends on 5 + # + # 6 6 6 6 + # +-+-+ +-+-+ +-+-+ +-+-+ + # 7 * 7 * 7 8 7 8 + # | | ++-+ | | +-++ | | + # 4 5 4 | * 4 | 5 4 5 + # +++ +++ +++ | | | | +++ +++ +++ + # 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 + # + node_times = { + 0: 0, + 1: 0, + 2: 0, + 3: 0, + 4: 1.0, + 5: 1.0, + 6: 3.0, + 7: 2.0, + 8: 2.0, + } + # (p, c, l, r) + edges = [ + (4, 0, 0, 10), + (4, 1, 0, 5), + (4, 1, 7, 10), + (5, 2, 0, 2), + (5, 2, 5, 10), + (5, 3, 0, 2), + (5, 3, 5, 10), + (7, 2, 2, 5), + (7, 4, 0, 10), + (8, 1, 5, 7), + (8, 5, 5, 10), + (6, 3, 2, 5), + (6, 5, 0, 2), + (6, 7, 0, 10), + (6, 8, 5, 10), + ] + tables = tskit.TableCollection(sequence_length=10) + samples = [0, 1, 2, 3, 5] + for n, t in node_times.items(): + flags = tskit.NODE_IS_SAMPLE if n in samples else 0 + tables.nodes.add_row(time=t, flags=flags) + for p, c, l, r in edges: + tables.edges.add_row(parent=p, child=c, left=l, right=r) + ts = tables.tree_sequence() + ets = extend_haplotypes(ts) + # nothing should have happened + ets.tables.assert_equals(tables) + self.verify_extend_haplotypes(ts) + self.naive_verify(ts) + + @pytest.mark.parametrize("seed", [3, 4, 5, 6]) + def test_wf(self, seed): + tables = wf.wf_sim(N=6, ngens=9, num_loci=100, deep_history=False, seed=seed) + tables.sort() + ts = tables.tree_sequence().simplify() + self.verify_extend_haplotypes(ts) + self.naive_verify(ts) + + +class TestExamples(TestExtendThings): + """ + Compare the ts method with local implementation. + """ + + def check(self, ts): + if np.any(tskit.is_unknown_time(ts.mutations_time)): + tables = ts.dump_tables() + tables.compute_mutation_times() + ts = tables.tree_sequence() + py_ets = extend_haplotypes(ts) + self.verify_simplify_equality(ts, py_ets) + lib_ts = ts.extend_haplotypes() + lib_ts.tables.assert_equals(py_ets.tables) + assert np.all(ts.genotype_matrix() == lib_ts.genotype_matrix()) + sts = ts.simplify() + lib_sts = lib_ts.simplify() + lib_sts.tables.assert_equals(sts.tables, ignore_provenance=True) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_suite_examples_defaults(self, ts): + if ts.num_migrations == 0: + self.check(ts) + else: + pass + with pytest.raises( + _tskit.LibraryError, match="TSK_ERR_MIGRATIONS_NOT_SUPPORTED" + ): + _ = ts.extend_haplotypes() + + @pytest.mark.parametrize("n", [3, 4, 5]) + def test_all_trees_ts(self, n): + ts = tsutil.all_trees_ts(n) + self.check(ts) diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 908c0dcb1d..dc0120fb8a 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -1496,21 +1496,28 @@ def test_time_units(self): ts.load_tables(tables) assert ts.get_time_units() == value - def test_extend_edges_bad_args(self): + def test_extend_haplotypes(self): + ts = self.get_example_tree_sequence(6) + ets2 = ts.extend_haplotypes(2) + ets4 = ts.extend_haplotypes(4) + assert ets2.get_num_nodes() == ts.get_num_nodes() + assert ets4.get_num_nodes() == ts.get_num_nodes() + + def test_extend_haplotypes_bad_args(self): ts1 = self.get_example_tree_sequence(10) with pytest.raises(TypeError): - ts1.extend_edges() + ts1.extend_haplotypes() with pytest.raises(TypeError, match="an integer"): - ts1.extend_edges("sdf") + ts1.extend_haplotypes("sdf") with pytest.raises(_tskit.LibraryError, match="positive"): - ts1.extend_edges(0) + ts1.extend_haplotypes(0) with pytest.raises(_tskit.LibraryError, match="positive"): - ts1.extend_edges(-1) + ts1.extend_haplotypes(-1) tsm = self.get_example_migration_tree_sequence() with pytest.raises( _tskit.LibraryError, match="TSK_ERR_MIGRATIONS_NOT_SUPPORTED" ): - tsm.extend_edges(1) + tsm.extend_haplotypes(1) @pytest.mark.parametrize( "stat_method_name", diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 35f21e14d4..52d0bd380f 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -6921,25 +6921,24 @@ def decapitate(self, time, *, flags=None, population=None, metadata=None): tables.delete_older(time) return tables.tree_sequence() - def extend_edges(self, max_iter=10): + def extend_haplotypes(self, max_iter=10): """ Returns a new tree sequence in which the span covered by ancestral nodes is "extended" to regions of the genome according to the following rule: - If an ancestral segment corresponding to node `n` has parent `p` and - child `c` on some portion of the genome, and on an adjacent segment of - genome `p` is the immediate parent of `c`, then `n` is inserted into the - edge from `p` to `c`. This involves extending the span of the edges - from `p` to `n` and `n` to `c` and reducing the span of the edge from - `p` to `c`. However, any edges whose child node is a sample will not - be modified. - - Since some edges may be removed entirely, this process reduces (or at - least does not increase) the number of edges in the tree sequence. - - *Note:* this is a somewhat experimental operation, and is probably not - what you are looking for. - - The method works by iterating over the genome to look for edges that can + If an ancestral segment corresponding to node `n` has ancestor `p` and + descendant `c` on some portion of the genome, and on an adjacent segment of + genome `p` is still an ancestor of `c`, then `n` is inserted into the + path from `p` to `c`. For instance, if `p` is the parent of `n` and `n` + is the parent of `c`, then the span of the edges from `p` to `n` and + `n` to `c` are extended, and the span of the edge from `p` to `c` is + reduced. Thus, the ancestral haplotype represented by `n` is extended + to a longer span of the genome. However, any edges whose child node is + a sample are not modified. + + Since some edges may be removed entirely, this process usually reduces + the number of edges in the tree sequence. + + The method works by iterating over the genome to look for paths that can be extended in this way; the maximum number of such iterations is controlled by ``max_iter``. @@ -6949,11 +6948,13 @@ def extend_edges(self, max_iter=10): be that distinct recombined segments were passed down separately from `p` to `c`). - If an edge that a mutation falls on is split by this operation, the - mutation's node may need to be moved. This is only unambiguous if the - mutation's time is known, so the method requires known mutation times. - See :meth:`.impute_unknown_mutations_time` if mutation times are - not known. + In the example above, if there was a mutation on the node above `c` + older than the time of `n` in the span into which `n` was extended, + then the mutation will now occur above `n`. So, this operation may change + mutations' nodes (but will not affect genotypes). This is only + unambiguous if the mutation's time is known, so the method requires + known mutation times. See :meth:`.impute_unknown_mutations_time` if + mutation times are not known. The method will not affect the marginal trees (so, if the original tree sequence was simplified, then following up with `simplify` will recover @@ -6968,7 +6969,7 @@ def extend_edges(self, max_iter=10): :rtype: tskit.TreeSequence """ max_iter = int(max_iter) - ll_ts = self._ll_tree_sequence.extend_edges(max_iter) + ll_ts = self._ll_tree_sequence.extend_haplotypes(max_iter) return TreeSequence(ll_ts) def subset(