diff --git a/c/examples/efficient_forward_simulation.c b/c/examples/efficient_forward_simulation.c new file mode 100644 index 0000000000..3e93a1bf4f --- /dev/null +++ b/c/examples/efficient_forward_simulation.c @@ -0,0 +1,289 @@ +#include +#include +#include +#include + +#include + +#define check_tsk_error(val) \ + if (val < 0) { \ + errx(EXIT_FAILURE, "line %d: %s", __LINE__, tsk_strerror(val)); \ + } + +typedef struct { + double left; + double right; + double parent_birth_time; + tsk_id_t parent; + tsk_id_t child; +} birth; + +int +cmp_birth(const void *lhs, const void *rhs) +{ + const birth *clhs = (const birth *) lhs; + const birth *crhs = (const birth *) rhs; + int ret = (clhs->parent_birth_time > crhs->parent_birth_time) + - (crhs->parent_birth_time > clhs->parent_birth_time); + if (ret == 0) { + ret = (clhs->parent > crhs->parent) - (crhs->parent > clhs->parent); + } + return ret; +} + +typedef struct { + birth *births; + tsk_size_t capacity; + tsk_size_t size; +} edge_buffer; + +int +edge_buffer_init(edge_buffer *buffer, tsk_size_t initial_capacity) +{ + int ret = 0; + if (initial_capacity == 0) { + ret = -1; + goto out; + } + buffer->births = (birth *) malloc(initial_capacity * sizeof(birth)); + buffer->capacity = initial_capacity; + buffer->size = 0; +out: + return ret; +} + +void +edge_buffer_realloc(edge_buffer *buffer) +{ + if (buffer->size + 1 >= buffer->capacity) { + buffer->capacity *= 2; + buffer->births + = (birth *) realloc(buffer->births, buffer->capacity * sizeof(birth)); + } +} + +void +edge_buffer_free(edge_buffer *buffer) +{ + if (buffer->births != NULL) { + free(buffer->births); + buffer->births = NULL; + } +} + +void +edge_buffer_buffer_birth(double left, double right, double parent_birth_time, + tsk_id_t parent, tsk_id_t child, edge_buffer *buffer) +{ + edge_buffer_realloc(buffer); + buffer->births[buffer->size].left = left; + buffer->births[buffer->size].right = right; + buffer->births[buffer->size].parent_birth_time = parent_birth_time; 
+ buffer->births[buffer->size].parent = parent; + buffer->births[buffer->size].child = child; + buffer->size += 1; +} + +void +edge_buffer_prep_for_simplification(edge_buffer *buffer) +{ + qsort(buffer->births, (size_t) buffer->size, sizeof(birth), cmp_birth); +} + +/* Overlapping generations are a pain point. + * Nodes that are currently "alive" can span an + * arbitrary range of ages >= 0.0. + * These nodes can have ancestry (edges). + * The children of these edges may not be "alive" (e.g., + * cannot be parents). + * + * We identify the range of such edges in a (simplified) + * table collection and lift them over into our buffer. + */ +void +edge_buffer_post_simplification(tsk_id_t *output_samples, tsk_size_t num_output_samples, + tsk_table_collection_t *tables, edge_buffer *buffer) +{ + /* node times cannot be negative, so this + * is a reasonable floor + */ + double max_alive_time = 0.0; + int64_t i, last_row_to_lift_over = -1; + tsk_size_t moved = 0; + + for (i = 0; i < num_output_samples; ++i) { + max_alive_time = TSK_MAX(tables->nodes.time[output_samples[i]], max_alive_time); + } + for (i = 0; i < tables->edges.num_rows; ++i) { + if (tables->nodes.time[tables->edges.parent[i]] <= max_alive_time) { + last_row_to_lift_over = (int) i; + } + } + if (last_row_to_lift_over > -1) { + for (i = 0; i < (tsk_size_t) last_row_to_lift_over; ++i) { + edge_buffer_buffer_birth(tables->edges.left[i], tables->edges.right[i], + tables->nodes.time[tables->edges.parent[i]], tables->edges.parent[i], + tables->edges.child[i], buffer); + } + for (i = (tsk_size_t) last_row_to_lift_over + 1; i < tables->edges.num_rows; + ++i) { + tables->edges.left[moved] = tables->edges.left[i]; + tables->edges.right[moved] = tables->edges.right[i]; + tables->edges.parent[moved] = tables->edges.parent[i]; + tables->edges.child[moved] = tables->edges.child[i]; + moved += 1; + } + tsk_edge_table_truncate(&tables->edges, moved); + } +} + +void +edge_buffer_clear(edge_buffer *buffer) +{ + buffer->size 
= 0; +} + +void +simulate( + tsk_table_collection_t *tables, int N, int T, int simplify_interval, double pdeath) +{ + tsk_id_t *buffer, *alive, *deaths, *replacements, *idmap, child, left_parent, + right_parent; + double breakpoint; + int ret, j, t; + edge_buffer new_births; + size_t ndeaths; + tsk_modular_simplifier_t simplifier; + + assert(simplify_interval != 0); // leads to division by zero + assert(pdeath > 0.0 && pdeath <= 1.0); + ret = edge_buffer_init(&new_births, 1000); + assert(ret == 0); + buffer = malloc(2 * N * sizeof(tsk_id_t)); + if (buffer == NULL) { + errx(EXIT_FAILURE, "Out of memory"); + } + idmap = malloc(N * sizeof(tsk_id_t)); + if (idmap == NULL) { + errx(EXIT_FAILURE, "Out of memory"); + } + deaths = malloc(N * sizeof(tsk_id_t)); + if (deaths == NULL) { + errx(EXIT_FAILURE, "Out of memory"); + } + tables->sequence_length = 1.0; + alive = buffer; + replacements = buffer + N; + for (j = 0; j < N; j++) { + alive[j] + = tsk_node_table_add_row(&tables->nodes, 0, T, TSK_NULL, TSK_NULL, NULL, 0); + check_tsk_error(alive[j]); + } + for (t = T - 1; t >= 0; t--) { + ndeaths = 0; + for (j = 0; j < N; j++) { + /* NOTE: the use of rand() is discouraged for + * research code and proper random number generator + * libraries should be preferred. + */ + if (rand() / (1. + RAND_MAX) <= pdeath) { + deaths[ndeaths] = j; + ++ndeaths; + } + } + for (j = 0; j < ndeaths; j++) { + child = tsk_node_table_add_row( + &tables->nodes, 0, t, TSK_NULL, TSK_NULL, NULL, 0); + check_tsk_error(child); + /* NOTE: the use of rand() is discouraged for + * research code and proper random number generator + * libraries should be preferred. + */ + left_parent = alive[(size_t)((rand() / (1. + RAND_MAX)) * N)]; + right_parent = alive[(size_t)((rand() / (1. + RAND_MAX)) * N)]; + do { + breakpoint = rand() / (1. 
+ RAND_MAX); + } while (breakpoint == 0); /* tiny proba of breakpoint being 0 */ + /* NOTE: invalid left/right values here CANNOT be caught + * by tsk_table_collection_check_integrity! + * (They are not present in the edge table when the + * simplifier is initialized, and the simplified tables + * will naively process the overlaps, resulting in invalid output.) + * + * It is therefore a precondition that input edges are okay. + */ + edge_buffer_buffer_birth(0., breakpoint, tables->nodes.time[left_parent], + left_parent, child, &new_births); + edge_buffer_buffer_birth(breakpoint, 1.0, tables->nodes.time[right_parent], + right_parent, child, &new_births); + replacements[j] = child; + } + /* replace deaths with births */ + for (j = 0; j < ndeaths; j++) { + alive[deaths[j]] = replacements[j]; + } + if (t % simplify_interval == 0) { + printf("Simplify at generation %lld: (%lld nodes %lld edges)", (long long) t, + (long long) tables->nodes.num_rows, (long long) tables->edges.num_rows); + edge_buffer_prep_for_simplification(&new_births); + idmap + = (tsk_id_t *) realloc(idmap, tables->nodes.num_rows * sizeof(tsk_id_t)); + ret = tsk_modular_simplifier_init(&simplifier, tables, alive, N, 0); + check_tsk_error(ret); + j = 0; + while (j < new_births.size) { + left_parent = new_births.births[j].parent; + while ( + j < new_births.size && new_births.births[j].parent == left_parent) { + ret = tsk_modular_simplifier_add_edge(&simplifier, + new_births.births[j].left, new_births.births[j].right, + new_births.births[j].parent, new_births.births[j].child); + check_tsk_error(ret); + j++; + } + ret = tsk_modular_simplifier_merge_ancestors(&simplifier, left_parent); + check_tsk_error(ret); + } + ret = tsk_modular_simplifier_finalise(&simplifier, idmap); + check_tsk_error(ret); + ret = tsk_modular_simplifier_free(&simplifier); + check_tsk_error(ret); + /* For fun/safety/paranoia */ + ret = tsk_table_collection_check_integrity(tables, TSK_CHECK_EDGE_ORDERING); + check_tsk_error(ret); + 
printf(" -> (%lld nodes %lld edges)\n", (long long) tables->nodes.num_rows, + (long long) tables->edges.num_rows); + for (j = 0; j < N; j++) { + alive[j] = idmap[alive[j]]; + assert(alive[j] != TSK_NULL); + } + /* The order of these next two steps MATTERS */ + edge_buffer_clear(&new_births); + edge_buffer_post_simplification(alive, N, tables, &new_births); + } + } + free(buffer); + free(idmap); + free(deaths); + edge_buffer_free(&new_births); +} + +int +main(int argc, char **argv) +{ + int ret; + tsk_table_collection_t tables; + + if (argc != 7) { + errx(EXIT_FAILURE, "usage: N T simplify-interval output-file seed pdeath"); + } + ret = tsk_table_collection_init(&tables, 0); + check_tsk_error(ret); + srand((unsigned) atoi(argv[5])); + simulate(&tables, atoi(argv[1]), atoi(argv[2]), atoi(argv[3]), atof(argv[6])); + ret = tsk_table_collection_dump(&tables, argv[4], 0); + check_tsk_error(ret); + + tsk_table_collection_free(&tables); + return 0; +} diff --git a/c/meson.build b/c/meson.build index c6150db2e7..39ed8d70c6 100644 --- a/c/meson.build +++ b/c/meson.build @@ -117,5 +117,8 @@ if not meson.is_subproject() executable('haploid_wright_fisher', sources: ['examples/haploid_wright_fisher.c'], link_with: [tskit_lib], dependencies: lib_deps) + executable('efficient_forward_simulation', + sources: ['examples/efficient_forward_simulation.c'], + link_with: [tskit_lib], dependencies: lib_deps) endif endif diff --git a/c/tests/test_tables.c b/c/tests/test_tables.c index 6de6675ff6..b81cbd8032 100644 --- a/c/tests/test_tables.c +++ b/c/tests/test_tables.c @@ -24,6 +24,9 @@ #include "testlib.h" #include "tskit/core.h" +#include "tskit/trees.h" +#include +#include #include #include @@ -8629,6 +8632,613 @@ test_simplify_metadata(void) tsk_table_collection_free(&tables); } +typedef struct { + tsk_id_t **parent; + tsk_id_t **child; + double **left; + double **right; + tsk_size_t max_nodes; + tsk_size_t *num_buffered_edges; +} fauxbuffer; + +static void +fauxbuffer_init(tsk_size_t 
max_nodes, fauxbuffer *buffer) +{ + tsk_size_t i; + buffer->parent = tsk_malloc(max_nodes * sizeof(tsk_id_t *)); + CU_ASSERT_FATAL(buffer->parent != NULL); + buffer->child = tsk_malloc(max_nodes * sizeof(tsk_id_t *)); + CU_ASSERT_FATAL(buffer->child != NULL); + buffer->left = tsk_malloc(max_nodes * sizeof(double *)); + CU_ASSERT_FATAL(buffer->parent != NULL); + buffer->right = tsk_malloc(max_nodes * sizeof(double *)); + CU_ASSERT_FATAL(buffer->right != NULL); + + buffer->num_buffered_edges = tsk_malloc(max_nodes * sizeof(tsk_size_t)); + CU_ASSERT_FATAL(buffer->num_buffered_edges != NULL); + + for (i = 0; i < max_nodes; ++i) { + buffer->parent[i] = NULL; + buffer->child[i] = NULL; + buffer->left[i] = NULL; + buffer->right[i] = NULL; + buffer->num_buffered_edges[i] = 0; + } + buffer->max_nodes = max_nodes; +} + +static void +fauxbuffer_buffer( + tsk_id_t parent, tsk_id_t child, double left, double right, fauxbuffer *buffer) +{ + CU_ASSERT_FATAL(parent < (tsk_id_t) buffer->max_nodes); + buffer->num_buffered_edges[parent] += 1; + buffer->parent[parent] = tsk_realloc( + buffer->parent[parent], buffer->num_buffered_edges[parent] * sizeof(tsk_id_t)); + buffer->child[parent] = tsk_realloc( + buffer->child[parent], buffer->num_buffered_edges[parent] * sizeof(tsk_id_t)); + buffer->left[parent] = tsk_realloc( + buffer->left[parent], buffer->num_buffered_edges[parent] * sizeof(double)); + buffer->right[parent] = tsk_realloc( + buffer->right[parent], buffer->num_buffered_edges[parent] * sizeof(double)); + buffer->parent[parent][buffer->num_buffered_edges[parent] - 1] = parent; + buffer->child[parent][buffer->num_buffered_edges[parent] - 1] = child; + buffer->left[parent][buffer->num_buffered_edges[parent] - 1] = left; + buffer->right[parent][buffer->num_buffered_edges[parent] - 1] = right; +} + +static void +fauxbuffer_free(fauxbuffer *buffer) +{ + tsk_size_t i; + + for (i = 0; i < buffer->max_nodes; ++i) { + tsk_safe_free(buffer->parent[i]); + tsk_safe_free(buffer->child[i]); 
+ tsk_safe_free(buffer->left[i]); + tsk_safe_free(buffer->right[i]); + } + tsk_safe_free(buffer->parent); + tsk_safe_free(buffer->child); + tsk_safe_free(buffer->left); + tsk_safe_free(buffer->right); + tsk_safe_free(buffer->num_buffered_edges); +} + +/* + * Start with this tree: + * 6 + * / \ + * / \ + * / \ + * / 5 + * 4 / \ + * / \ / \ + * 0 1 2 3 + * + * Add data like a fake forward sim to give: + * + * 6 + * / \ + * / \ + * / \ + * / 5 + * 4 / \ + * / \ / \ + * 0 1 2 3 + * | | + * 7 8--- <- new_parent 1 and 2, resp. + * | | | + * 9 10 11 <- new child 1, 2, and 3, resp. + * + * Then, we simplify w.r.to [9, 10, 11]. + */ +static void +make_single_tree_for_testing_modular_simplify( + tsk_table_collection_t *tables, tsk_edge_table_t *new_edges, tsk_id_t **samples) +{ + int ret; + tsk_id_t new_parent1, new_parent2, new_child1, new_child2, new_child3; + tsk_size_t row; + ret = tsk_table_collection_init(tables, 0); + tables->sequence_length = 1.0; + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_edge_table_init(new_edges, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + *samples = tsk_malloc(3 * sizeof(tsk_id_t)); + CU_ASSERT_TRUE(samples != NULL) + + parse_nodes(single_tree_ex_nodes, &tables->nodes); + parse_edges(single_tree_ex_edges, &tables->edges); + + /* "record new births" into node table */ + new_parent1 = tsk_node_table_add_row(&tables->nodes, 0, -1.0, -1, -1, NULL, 0); + CU_ASSERT_TRUE_FATAL(new_parent1 >= 0); + new_parent2 = tsk_node_table_add_row(&tables->nodes, 0, -1.0, -1, -1, NULL, 0); + CU_ASSERT_TRUE_FATAL(new_parent2 >= 0); + new_child1 = tsk_node_table_add_row(&tables->nodes, 0, -2.0, -1, -1, NULL, 0); + CU_ASSERT_TRUE_FATAL(new_child1 >= 0); + new_child2 = tsk_node_table_add_row(&tables->nodes, 0, -2.0, -1, -1, NULL, 0); + CU_ASSERT_TRUE_FATAL(new_child2 >= 0); + new_child3 = tsk_node_table_add_row(&tables->nodes, 0, -2.0, -1, -1, NULL, 0); + CU_ASSERT_TRUE_FATAL(new_child3 >= 0); + + for (row = 0; row < tables->nodes.num_rows; ++row) { + // make all times >= 
0.0. + tables->nodes.time[row] += 2.0; + } + + /* record edges */ + ret = (int) tsk_edge_table_add_row( + new_edges, 0, tables->sequence_length, 0, new_parent1, NULL, 0); + CU_ASSERT_TRUE_FATAL(ret >= 0); + ret = (int) tsk_edge_table_add_row( + new_edges, 0, tables->sequence_length, 2, new_parent2, NULL, 0); + CU_ASSERT_TRUE_FATAL(ret >= 0); + ret = (int) tsk_edge_table_add_row( + new_edges, 0, tables->sequence_length, new_parent1, new_child1, NULL, 0); + CU_ASSERT_TRUE_FATAL(ret >= 0); + ret = (int) tsk_edge_table_add_row( + new_edges, 0, tables->sequence_length, new_parent2, new_child2, NULL, 0); + CU_ASSERT_TRUE_FATAL(ret >= 0); + ret = (int) tsk_edge_table_add_row( + new_edges, 0, tables->sequence_length, new_parent2, new_child3, NULL, 0); + CU_ASSERT_TRUE_FATAL(ret >= 0); + (*samples)[0] = new_child1; + (*samples)[1] = new_child2; + (*samples)[2] = new_child3; +} + +/* This is our starting tree. + * We will add additional births + * to 0/1/3 and then use that as + * the basis for testing. 
+ * + * Alive nodes will be 0, 1, 2, 3, 4 + * in order to generated the complexity + * we need + * + * 1.20┊ ┊ 8 ┊ ┊ + * ┊ ┊ ┏━┻━┓ ┊ ┊ + * 1.00┊ 7 ┊ ┃ ┃ ┊ ┊ + * ┊ ┏━┻━┓ ┊ ┃ ┃ ┊ ┊ + * 0.70┊ ┃ ┃ ┊ ┃ ┃ ┊ 6 ┊ + * ┊ ┃ ┃ ┊ ┃ ┃ ┊ ┏━┻━┓ ┊ + * 0.50┊ ┃ 5 ┊ 5 ┃ ┊ ┃ 5 ┊ + * ┊ ┃ ┏━┻┓ ┊ ┏┻━┓ ┃ ┊ ┃ ┏━┻┓ ┊ + * 0.40┊ ┃ ┃ 4 ┊ 4 ┃ ┃ ┊ ┃ ┃ 4 ┊ + * ┊ ┃ ┃ ┏┻┓ ┊ ┏┻┓ ┃ ┃ ┊ ┃ ┃ ┏┻┓ ┊ + * 0.20┊ ┃ ┃ ┃ 3 ┊ ┃ ┃ ┃ 3 ┊ ┃ ┃ ┃ 3 ┊ + * ┊ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┊ + * 0.10┊ ┃ 1 2 ┊ ┃ 2 1 ┊ ┃ 1 2 ┊ + * ┊ ┃ ┊ ┃ ┊ ┃ ┊ + * 0.00┊ 0 ┊ 0 ┊ 0 ┊ + * 0.00 2.00 8.00 10.00 + */ +static void +make_overlapping_generations_trees_for_testing_modular_simplify( + tsk_table_collection_t *tables, tsk_edge_table_t *new_edges, tsk_id_t **samples) +{ + int ret; + tsk_id_t new_child0, new_child1, new_child2; + tsk_size_t row, moved, edge; + int last_row_to_lift_over; + fauxbuffer buffer; + ret = tsk_table_collection_init(tables, 0); + double tmax; + CU_ASSERT_EQUAL_FATAL(ret, 0); + parse_edges(internal_sample_ex_edges, &tables->edges); + parse_nodes(internal_sample_ex_nodes, &tables->nodes); + tables->sequence_length = 10.0; + fauxbuffer_init(12, &buffer); + ret = tsk_edge_table_init(new_edges, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + *samples = tsk_malloc(3 * sizeof(tsk_id_t)); + CU_ASSERT_TRUE_FATAL(*samples != NULL); + + new_child0 = tsk_node_table_add_row(&tables->nodes, 0, -1.0, -1, -1, NULL, 0); + CU_ASSERT_TRUE_FATAL(new_child0 > 0); + new_child1 = tsk_node_table_add_row(&tables->nodes, 0, -1.0, -1, -1, NULL, 0); + CU_ASSERT_TRUE_FATAL(new_child1 > 0); + new_child2 = tsk_node_table_add_row(&tables->nodes, 0, -1.0, -1, -1, NULL, 0); + CU_ASSERT_TRUE_FATAL(new_child2 > 0); + + for (row = 0; row < tables->nodes.num_rows; ++row) { + // for some reason, the fixture sets pop to 0... + tables->nodes.population[row] = -1; + // make all times >= 0.0. 
+ tables->nodes.time[row] += 1.0; + } + ret = (int) tsk_table_collection_check_integrity(tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tmax = tables->nodes.time[4]; + + /* We need to cheat, "stealing" edges from the input + * table and putting them into the new_edges so that + * we can mimic what a "real" implementation needs + * to do. + * + * Initial edge fixture: + * const char *internal_sample_ex_edges = "2 8 4 0\n" + * "0 10 4 2\n" + * "0 2 4 3\n" + * "8 10 4 3\n" + * "0 10 5 1,4\n" + * "8 10 6 0,5\n" + * "0 2 7 0,5\n" + * "2 8 8 3,5\n"; + * We can rewrite it to make the edge_id easier to read: + * "2 8 4 0\n" + * "0 10 4 2\n" + * "0 2 4 3\n" + * "8 10 4 3\n" + * "0 10 5 1\n" + * "0 10 5 4\n" + * "8 10 6 0\n" + * "8 10 6 5\n" + * "0 2 7 0\n" + * "0 2 7 5\n" + * "2 8 8 3\n"; + * "2 8 8 5\n"; + */ + last_row_to_lift_over = -1; + for (row = 0; row < tables->edges.num_rows; ++row) { + if (tables->nodes.time[tables->edges.parent[row]] <= tmax) { + last_row_to_lift_over = (int) row; + } + } + moved = 0; + if (last_row_to_lift_over > -1) { + for (row = 0; row < (tsk_size_t) last_row_to_lift_over; ++row) { + fauxbuffer_buffer( + tables->edges.parent[(tsk_size_t) last_row_to_lift_over - row], + tables->edges.child[(tsk_size_t) last_row_to_lift_over - row], + tables->edges.left[(tsk_size_t) last_row_to_lift_over - row], + tables->edges.right[(tsk_size_t) last_row_to_lift_over - row], &buffer); + } + for (row = (tsk_size_t) last_row_to_lift_over + 1; row < tables->edges.num_rows; + ++row) { + tables->edges.left[moved] = tables->edges.left[row]; + tables->edges.right[moved] = tables->edges.right[row]; + tables->edges.parent[moved] = tables->edges.parent[row]; + tables->edges.child[moved] = tables->edges.child[row]; + moved += 1; + } + } + tsk_edge_table_truncate(&tables->edges, moved); + /* To maintain sanity, transmit non-recombinant genomes */ + fauxbuffer_buffer(0, new_child0, 0., tables->sequence_length, &buffer); + fauxbuffer_buffer(1, new_child1, 0., 
tables->sequence_length, &buffer); + fauxbuffer_buffer(3, new_child2, 0., tables->sequence_length, &buffer); + + for (row = 0; row < buffer.max_nodes; ++row) { + for (edge = 0; edge < buffer.num_buffered_edges[row]; ++edge) { + tsk_edge_table_add_row(new_edges, buffer.left[row][edge], + buffer.right[row][edge], buffer.parent[row][edge], + buffer.child[row][edge], NULL, 0); + } + } + + (*samples)[0] = new_child0; + (*samples)[1] = new_child1; + (*samples)[2] = new_child2; + + fauxbuffer_free(&buffer); +} + +static int +run_test_modular_simplifier(tsk_table_collection_t *tables, tsk_edge_table_t *new_edges, + tsk_id_t *row_order, tsk_size_t len_row_order, tsk_id_t *samples, + tsk_size_t len_samples) +{ + int ret; + tsk_table_collection_t standard_tables; + tsk_modular_simplifier_t simplifier; + tsk_treeseq_t treeseq, standard_treeseq; + tsk_tree_t standard_tree, tree; + tsk_id_t last_parent; + tsk_size_t row, row_for_parent; + double ttl_time, standard_ttl_time; + + ret = tsk_table_collection_copy(tables, &standard_tables, 0); + if (ret < 0) { + goto out; + } + ret = tsk_edge_table_append_columns(&standard_tables.edges, new_edges->num_rows, + new_edges->left, new_edges->right, new_edges->parent, new_edges->child, NULL, + NULL); + if (ret < 0) { + goto out; + } + ret = tsk_table_collection_sort(&standard_tables, NULL, 0); + if (ret < 0) { + goto out; + } + CU_ASSERT_EQUAL_FATAL( + standard_tables.edges.num_rows, tables->edges.num_rows + new_edges->num_rows); + CU_ASSERT_EQUAL_FATAL(standard_tables.nodes.num_rows, tables->nodes.num_rows); + + ret = tsk_table_collection_simplify(&standard_tables, samples, 3, 0, NULL); + if (ret < 0) { + goto out; + } + ret = tsk_modular_simplifier_init(&simplifier, tables, samples, len_samples, 0); + if (ret < 0) { + goto out; + } + /* Pseudocode that we are mocking: + * For each parent of a new edge: + * - add that edge to the segment queue. + * - When done, finalise the queue and merge ancestors. 
+ * + * If our buffer is wrong, we will have parents unsorted by time + * and/or the same parent processed in different loop iterations. + * Each case is an error that MUST be handled. + * It is trivial to show that not handling the errors can give rise + * to invalid table collections / tree sequences. + * (TODO: The requirement for error handling must be documented + * in tables.h.) + * + * Production code should use an input other than + * an edge table. + * (How edges are sorted is an internal detail + * and cannot be used for testing.) + */ + for (row = 0; row < len_row_order; ++row) { + last_parent = new_edges->parent[row_order[row]]; + row_for_parent = 0; + for (row_for_parent = 0; + (tsk_size_t) row_order[row] + row_for_parent < new_edges->num_rows + && new_edges->parent[(tsk_size_t) row_order[row] + row_for_parent] + == last_parent; + ++row_for_parent) { + CU_ASSERT_FATAL( + (tsk_size_t) row_order[row] + row_for_parent < new_edges->num_rows); + ret = tsk_modular_simplifier_add_edge(&simplifier, + new_edges->left[(tsk_size_t) row_order[row] + row_for_parent], + new_edges->right[(tsk_size_t) row_order[row] + row_for_parent], + new_edges->parent[(tsk_size_t) row_order[row] + row_for_parent], + new_edges->child[(tsk_size_t) row_order[row] + row_for_parent]); + if (ret < 0) { + goto out; + } + } + ret = tsk_modular_simplifier_merge_ancestors(&simplifier, last_parent); + if (ret < 0) { + goto out; + } + } + /* Simplification's internal cleanup. + * Should NOT be called if above loop errors. + * We know that not calling it and calling + * "modular simplifier free" does not leak + * because valgrind is happy. + * + * Now, we have processed all (child) nodes whose births are + * MORE RECENT than those in the input tables. 
+ */ + ret = tsk_modular_simplifier_finalise(&simplifier, NULL); + if (ret < 0) { + goto out; + } + + // Now, we can compare various properties of the two table collections + CU_ASSERT_EQUAL_FATAL(standard_tables.edges.num_rows, tables->edges.num_rows); + CU_ASSERT_EQUAL_FATAL(standard_tables.nodes.num_rows, tables->nodes.num_rows); + + ret = tsk_table_collection_build_index(&standard_tables, 0); + if (ret < 0) { + goto out; + } + ret = tsk_table_collection_build_index(tables, 0); + if (ret < 0) { + goto out; + } + + ret = tsk_treeseq_init(&standard_treeseq, &standard_tables, 0); + if (ret < 0) { + goto out; + } + ret = tsk_treeseq_init(&treeseq, tables, 0); + if (ret < 0) { + goto out; + } + CU_ASSERT_EQUAL_FATAL(tsk_treeseq_get_num_trees(&standard_treeseq), + tsk_treeseq_get_num_trees(&treeseq)); + ret = tsk_tree_init(&standard_tree, &standard_treeseq, 0); + if (ret < 0) { + goto out; + } + ret = tsk_tree_init(&tree, &treeseq, 0); + if (ret < 0) { + goto out; + } + ret = tsk_tree_first(&standard_tree); + CU_ASSERT_EQUAL_FATAL(ret, TSK_TREE_OK); + for (ret = tsk_tree_first(&tree); ret == TSK_TREE_OK; ret = tsk_tree_next(&tree)) { + tsk_tree_get_total_branch_length(&tree, -1, &ttl_time); + tsk_tree_get_total_branch_length(&standard_tree, -1, &standard_ttl_time); + CU_ASSERT_TRUE(ttl_time - standard_ttl_time <= 1e-9); + tsk_tree_next(&standard_tree); + } + tsk_treeseq_free(&standard_treeseq); + tsk_treeseq_free(&treeseq); + tsk_tree_free(&standard_tree); + tsk_tree_free(&tree); + CU_ASSERT_EQUAL_FATAL(ret, 0); +out: + tsk_table_collection_free(&standard_tables); + tsk_safe_free(samples); + tsk_modular_simplifier_free(&simplifier); + return ret; +} + +static void +run_test_modular_simplify_single_tree( + tsk_id_t *row_order, tsk_size_t len_row_order, int expected_result) +{ + int ret; + tsk_table_collection_t tables; + tsk_edge_table_t new_edges; + tsk_id_t *samples; + make_single_tree_for_testing_modular_simplify(&tables, &new_edges, &samples); + ret = 
run_test_modular_simplifier( + &tables, &new_edges, row_order, len_row_order, samples, 3); + tsk_table_collection_free(&tables); + tsk_edge_table_free(&new_edges); + CU_ASSERT_EQUAL_FATAL(ret, expected_result); +} + +static void +test_table_collection_modular_simplify_simple_tree(void) +{ + tsk_id_t row_order[4] = { 3, 2, 1, 0 }; + run_test_modular_simplify_single_tree(&row_order[0], 4, 0); +} + +static void +test_table_collection_modular_simplify_simple_tree_discontiguous_parents(void) +{ + tsk_id_t row_order[4] = { 3, 2, 3, 1 }; + run_test_modular_simplify_single_tree( + &row_order[0], 4, TSK_ERR_EDGES_NONCONTIGUOUS_PARENTS); +} + +static void +test_table_collection_modular_simplify_simple_tree_add_edges_wrong_birth_order(void) +{ + tsk_id_t row_order[4] = { 0, 1, 2, 3 }; + run_test_modular_simplify_single_tree( + &row_order[0], 4, TSK_ERR_EDGES_NOT_SORTED_PARENT_TIME); +} + +static void +test_table_collection_modular_simplify_add_invalid_parent_or_child(void) +{ + int ret; + tsk_table_collection_t tables; + tsk_edge_table_t new_edges; + tsk_modular_simplifier_t simplifier; + tsk_id_t *samples; + make_single_tree_for_testing_modular_simplify(&tables, &new_edges, &samples); + ret = tsk_modular_simplifier_init(&simplifier, &tables, samples, 3, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_modular_simplifier_add_edge( + &simplifier, 0., 1, TSK_NULL, new_edges.child[4]); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NULL_PARENT); + ret = tsk_modular_simplifier_add_edge( + &simplifier, 0., 1, new_edges.parent[4], TSK_NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NULL_CHILD); + ret = tsk_modular_simplifier_add_edge(&simplifier, 0., 1, 10000, new_edges.child[4]); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + ret = tsk_modular_simplifier_add_edge( + &simplifier, 0., 1, new_edges.parent[4], 10000); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + + tsk_safe_free(samples); + tsk_table_collection_free(&tables); + tsk_edge_table_free(&new_edges); + 
tsk_modular_simplifier_free(&simplifier); +} + +static void +test_table_collection_modular_simplify_add_child_with_invalid_time(void) +{ + int ret; + tsk_table_collection_t tables; + tsk_edge_table_t new_edges; + tsk_modular_simplifier_t simplifier; + tsk_id_t *samples; + make_single_tree_for_testing_modular_simplify(&tables, &new_edges, &samples); + /* edit the first child's birth time to be "very wrong" */ + tables.nodes.time[new_edges.child[4]] = DBL_MAX; + ret = tsk_modular_simplifier_init(&simplifier, &tables, samples, 3, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_modular_simplifier_add_edge( + &simplifier, 0., 1, new_edges.parent[4], new_edges.child[4]); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_NODE_TIME_ORDERING); + + tsk_safe_free(samples); + tsk_table_collection_free(&tables); + tsk_edge_table_free(&new_edges); + tsk_modular_simplifier_free(&simplifier); +} + +static void +run_test_modular_simplify_overlapping_generations( + tsk_id_t *row_order, tsk_size_t len_row_order, int expected_result) +{ + int ret; + tsk_table_collection_t tables; + tsk_edge_table_t new_edges; + tsk_id_t *samples; + make_overlapping_generations_trees_for_testing_modular_simplify( + &tables, &new_edges, &samples); + ret = run_test_modular_simplifier( + &tables, &new_edges, row_order, len_row_order, samples, 3); + tsk_table_collection_free(&tables); + tsk_edge_table_free(&new_edges); + CU_ASSERT_EQUAL_FATAL(ret, expected_result); +} + +static void +test_table_collection_modular_simplify_overlapping_generations(void) +{ + tsk_id_t row_order[4] = { 0, 1, 2, 3 }; + run_test_modular_simplify_overlapping_generations(&row_order[0], 4, 0); +} + +static void +test_table_collection_modular_simplify_overlapping_generations_parent_time_error(void) +{ + tsk_id_t row_order[4] = { 2, 3, 0, 1 }; + run_test_modular_simplify_overlapping_generations( + &row_order[0], 4, TSK_ERR_EDGES_NOT_SORTED_PARENT_TIME); +} + +/* This hits part of simplifier intialisation that is + * NOT part of table 
integrity checks + */ +static void +test_table_collection_modular_simplify_bad_samples(void) +{ + int ret; + tsk_table_collection_t tables; + tsk_edge_table_t new_edges; + tsk_modular_simplifier_t simplifier; + tsk_id_t *samples; + make_single_tree_for_testing_modular_simplify(&tables, &new_edges, &samples); + samples[0] = -1; + ret = tsk_modular_simplifier_init(&simplifier, &tables, samples, 3, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + tsk_safe_free(samples); + tsk_table_collection_free(&tables); + tsk_edge_table_free(&new_edges); + tsk_modular_simplifier_free(&simplifier); +} + +static void +test_table_collection_modular_simplify_table_integrity_check_fail(void) +{ + int ret; + tsk_table_collection_t tables; + tsk_edge_table_t new_edges; + tsk_modular_simplifier_t simplifier; + tsk_id_t *samples; + double temp; + make_single_tree_for_testing_modular_simplify(&tables, &new_edges, &samples); + temp = tables.nodes.time[4]; + // now we have a parent/child time violation + tables.nodes.time[4] = tables.nodes.time[3]; + tables.nodes.time[3] = temp; + ret = tsk_modular_simplifier_init(&simplifier, &tables, samples, 3, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_NODE_TIME_ORDERING); + tsk_safe_free(samples); + tsk_table_collection_free(&tables); + tsk_edge_table_free(&new_edges); + tsk_modular_simplifier_free(&simplifier); +} + static void test_edge_update_invalidates_index(void) { @@ -8791,7 +9401,8 @@ test_sort_tables_offsets(void) CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_TRUE(tsk_table_collection_equals(&tables, ©, 0)); - /* Check that sorting would have had no effect as individuals not in default sort*/ + /* Check that sorting would have had no effect as individuals not in default + * sort*/ ret = tsk_table_collection_sort(&tables, NULL, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_TRUE(tsk_table_collection_equals(&tables, ©, 0)); @@ -11592,6 +12203,27 @@ main(int argc, char **argv) { "test_table_collection_equals_options", 
test_table_collection_equals_options }, { "test_table_collection_simplify_errors", test_table_collection_simplify_errors }, + { "test_table_collection_modular_simplify_simple_tree", + test_table_collection_modular_simplify_simple_tree }, + { "test_table_collection_modular_simplify_simple_tree_discontiguous_parents", + test_table_collection_modular_simplify_simple_tree_discontiguous_parents }, + { "test_table_collection_modular_simplify_simple_tree_add_edges_wrong_birth_" + "order", + test_table_collection_modular_simplify_simple_tree_add_edges_wrong_birth_order }, + { "test_table_collection_modular_simplify_add_invalid_parent_or_child", + test_table_collection_modular_simplify_add_invalid_parent_or_child }, + { "test_table_collection_modular_simplify_add_child_with_invalid_time", + test_table_collection_modular_simplify_add_child_with_invalid_time }, + { "test_table_collection_modular_simplify_bad_samples", + test_table_collection_modular_simplify_bad_samples }, + { "test_table_collection_modular_simplify_table_integrity_check_fail", + test_table_collection_modular_simplify_table_integrity_check_fail }, + { "test_table_collection_modular_simplify_overlapping_generations", + test_table_collection_modular_simplify_overlapping_generations }, + { "test_table_collection_modular_simplify_overlapping_generations_parent_" + "time_" + "error", + test_table_collection_modular_simplify_overlapping_generations_parent_time_error }, { "test_table_collection_time_units", test_table_collection_time_units }, { "test_table_collection_reference_sequence", test_table_collection_reference_sequence }, diff --git a/c/tskit/tables.c b/c/tskit/tables.c index 8eea85f5ad..c9f237e781 100644 --- a/c/tskit/tables.c +++ b/c/tskit/tables.c @@ -23,6 +23,7 @@ * SOFTWARE. 
*/ +#include "tskit/core.h" #include #include #include @@ -13607,3 +13608,175 @@ tsk_diff_iter_next(tsk_diff_iter_t *self, double *ret_left, double *ret_right, self->tree_left = right; return ret; } + +typedef struct __tsk_modular_simplifier_impl_t { + simplifier_t simplifier; + tsk_id_t last_parent_processed; /* last parent merged, or TSK_NULL */ + tsk_id_t *input_node_visited; /* per input node: 1 once merged as a parent */ + tsk_size_t num_input_nodes; /* length of input_node_visited */ + double last_parent_time; /* parent time from the last add_edge that passed the time checks */ + /*double minimum_input_node_time;*/ +} tsk_modular_simplifier_impl_t; + +/* Allocate the private state of self and initialise the wrapped simplifier from + * tables, samples and options. Returns 0 or a TSK_ERR_* code; call + * tsk_modular_simplifier_free afterwards even if this fails. */ +int +tsk_modular_simplifier_init(tsk_modular_simplifier_t *self, + tsk_table_collection_t *tables, const tsk_id_t *samples, tsk_size_t num_samples, + tsk_flags_t options) +{ + int ret = 0; + tsk_size_t i; + self->pimpl = tsk_malloc(sizeof(tsk_modular_simplifier_impl_t)); + if (self->pimpl == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + /* Zero everything so that a failure below leaves no indeterminate pointers + * for tsk_modular_simplifier_free. NOTE(review): assumes simplifier_free + * accepts a zeroed simplifier_t - confirm. */ + tsk_memset(self->pimpl, 0, sizeof(*self->pimpl)); + /* Have to init this array here b/c the internal simplifier + * "steals" the input tables + */ + self->pimpl->input_node_visited + = tsk_malloc(tables->nodes.num_rows * sizeof(tsk_id_t)); + if (self->pimpl->input_node_visited == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + for (i = 0; i < tables->nodes.num_rows; ++i) { + self->pimpl->input_node_visited[i] = 0; + } + self->pimpl->num_input_nodes = tables->nodes.num_rows; + self->pimpl->last_parent_processed = TSK_NULL; + /* Lowest finite double, so the first parent always passes the time-order check. + * DBL_MIN is the smallest positive double and would reject times <= 0. */ + self->pimpl->last_parent_time = -DBL_MAX; + + /* NOTE: the original intent here was to catch + * issues where children aren't sorted properly, + * but it is clear I didn't get that right.
+ * + * Need to write more tests to see if I can + * trigger any problems + */ + /* + self->pimpl->minimum_input_node_time = DBL_MAX; + for (i = 0; i < tables->edges.num_rows; ++i) { + self->pimpl->minimum_input_node_time + = TSK_MIN(self->pimpl->minimum_input_node_time, + tables->nodes.time[tables->edges.child[i]]); + } + */ + + /* Now that we have set up the pimpl state, + * we can let the usual init happen + */ + ret = simplifier_init( + &self->pimpl->simplifier, samples, num_samples, tables, options); +out: + return ret; +} + +/* Validate the edge (left, right, parent, child) and pass (left, right, child) + * to simplifier_extract_ancestry. Edges must arrive with non-decreasing parent + * time, and a parent that was already merged may only reappear if it was the + * most recently merged one. Returns 0 or a TSK_ERR_* code. */ +int +tsk_modular_simplifier_add_edge(tsk_modular_simplifier_t *self, double left, + double right, tsk_id_t parent, tsk_id_t child) +{ + int ret = 0; + + if (parent == TSK_NULL) { + ret = TSK_ERR_NULL_PARENT; + goto out; + } + if (child == TSK_NULL) { + ret = TSK_ERR_NULL_CHILD; + goto out; + } + /* Reject negative IDs too: they would index before the node arrays below. */ + if (parent < 0 || parent >= (tsk_id_t) self->pimpl->num_input_nodes || child < 0 + || child >= (tsk_id_t) self->pimpl->num_input_nodes) { + ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + goto out; + } + + if (self->pimpl->simplifier.input_tables.nodes.time[child] + >= self->pimpl->simplifier.input_tables.nodes.time[parent]) { + ret = TSK_ERR_BAD_NODE_TIME_ORDERING; + goto out; + } + + if (self->pimpl->simplifier.input_tables.nodes.time[parent] + < self->pimpl->last_parent_time) { + ret = TSK_ERR_EDGES_NOT_SORTED_PARENT_TIME; + goto out; + } + self->pimpl->last_parent_time + = self->pimpl->simplifier.input_tables.nodes.time[parent]; + + if (parent != self->pimpl->last_parent_processed + && self->pimpl->last_parent_processed != TSK_NULL) { + if (self->pimpl->input_node_visited[parent] != 0) { + ret = TSK_ERR_EDGES_NONCONTIGUOUS_PARENTS; + goto out; + } + } + ret = simplifier_extract_ancestry(&self->pimpl->simplifier, left, right, child); +out: + return ret; +} + +/* Run simplifier_merge_ancestors for parent and, on success, mark parent as + * visited and reset the segment queue. Returns 0 or a TSK_ERR_* code. */ +int +tsk_modular_simplifier_merge_ancestors(tsk_modular_simplifier_t *self, tsk_id_t parent) +{ + int ret = 0; + + /* parent indexes input_node_visited below, so bounds-check it first. */ + if (parent < 0 || parent >= (tsk_id_t) self->pimpl->num_input_nodes) { + ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + goto out; + } + ret = simplifier_merge_ancestors(&self->pimpl->simplifier, parent); + if (ret != 0) { + goto out; + } + /* mark this input
parent as "seen" */ + self->pimpl->input_node_visited[parent] = 1; + self->pimpl->simplifier.segment_queue_size = 0; + self->pimpl->last_parent_processed = parent; +out: + return ret; +} + +/* Run the wrapped simplifier via simplifier_run, passing node_map through. + * Returns the value simplifier_run returns. */ +int +tsk_modular_simplifier_finalise(tsk_modular_simplifier_t *self, tsk_id_t *node_map) +{ + int ret = 0; + simplifier_t *simplifier = &self->pimpl->simplifier; + ret = simplifier_run(simplifier, node_map); + return ret; +} + +/* Release everything owned by self. Safe when init failed to allocate pimpl; the + * buffers are freed even if simplifier_free reports an error, which is returned. */ +int +tsk_modular_simplifier_free(tsk_modular_simplifier_t *self) +{ + int ret = 0; + + if (self->pimpl != NULL) { + ret = simplifier_free(&self->pimpl->simplifier); + tsk_safe_free(self->pimpl->input_node_visited); + tsk_safe_free(self->pimpl); + } + return ret; +} diff --git a/c/tskit/tables.h b/c/tskit/tables.h index 38f3096c9d..0fb6b64d67 100644 --- a/c/tskit/tables.h +++ b/c/tskit/tables.h @@ -4784,6 +4784,30 @@ int tsk_diff_iter_next(tsk_diff_iter_t *self, double *left, double *right, tsk_edge_list_t *edges_out, tsk_edge_list_t *edges_in); void tsk_diff_iter_print_state(const tsk_diff_iter_t *self, FILE *out); +/* TODO: document */ +typedef struct { + /* don't leak private types into public API */ + struct __tsk_modular_simplifier_impl_t *pimpl; +} tsk_modular_simplifier_t; + +/* TODO: document */ +int tsk_modular_simplifier_init(tsk_modular_simplifier_t *self, + tsk_table_collection_t *tables, const tsk_id_t *samples, tsk_size_t num_samples, + tsk_flags_t options); +/* TODO: document */ +int tsk_modular_simplifier_free(tsk_modular_simplifier_t *self); +/* TODO: document */ +int tsk_modular_simplifier_add_edge(tsk_modular_simplifier_t *self, double left, + double right, tsk_id_t parent, tsk_id_t child); +/* TODO: document */ +int tsk_modular_simplifier_merge_ancestors( + tsk_modular_simplifier_t *self, tsk_id_t parent); + +/* TODO: document */ +// runs the simplifier, thus processing ancient edges +// present in the input edge table.
+int tsk_modular_simplifier_finalise(tsk_modular_simplifier_t *self, tsk_id_t *node_map); + #ifdef __cplusplus } #endif diff --git a/python/tests/simplify.py b/python/tests/simplify.py index 02e0482cca..c009ee8fa4 100644 --- a/python/tests/simplify.py +++ b/python/tests/simplify.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2019-2022 Tskit Developers +# Copyright (c) 2019-2023 Tskit Developers # Copyright (c) 2015-2018 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -114,6 +114,8 @@ def __init__( filter_nodes=True, update_sample_flags=True, ): + # DELETE ME + self.parent_edges_processed = 0 self.ts = ts self.n = len(sample) self.reduce_to_site_topology = reduce_to_site_topology @@ -397,6 +399,7 @@ def process_parent_edges(self, edges): """ Process all of the edges for a given parent. """ + self.parent_edges_processed += len(edges) assert len({e.parent for e in edges}) == 1 parent = edges[0].parent S = [] @@ -535,6 +538,14 @@ def insert_input_roots(self): offset += 1 self.sort_offset = offset + def finalise(self): + if self.keep_input_roots: + self.insert_input_roots() + self.finalise_sites() + self.finalise_references() + if self.sort_offset != -1: + self.tables.sort(edge_start=self.sort_offset) + def simplify(self): if self.ts.num_edges > 0: all_edges = list(self.ts.edges()) @@ -545,12 +556,7 @@ def simplify(self): edges = [] edges.append(e) self.process_parent_edges(edges) - if self.keep_input_roots: - self.insert_input_roots() - self.finalise_sites() - self.finalise_references() - if self.sort_offset != -1: - self.tables.sort(edge_start=self.sort_offset) + self.finalise() ts = self.tables.tree_sequence() return ts, self.node_id_map diff --git a/python/tests/test_forward_sims.py b/python/tests/test_forward_sims.py new file mode 100644 index 0000000000..54ea9039fd --- /dev/null +++ b/python/tests/test_forward_sims.py @@ -0,0 +1,259 @@ +# MIT License +# +# Copyright (c) 2023 Tskit Developers +# +# 
Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Python implementation of the low-level supporting code for forward simulations. 
+""" +import itertools +import random + +import numpy as np +import pytest + +import tskit +from tests import simplify + + +class BirthBuffer: + def __init__(self): + self.edges = {} + self.parents = [] + + def add_edge(self, left, right, parent, child): + if parent not in self.edges: + self.parents.append(parent) + self.edges[parent] = [] + self.edges[parent].append((child, left, right)) + + def clear(self): + self.edges = {} + self.parents = [] + + def __str__(self): + s = "" + for parent in self.parents: + for child, left, right in self.edges[parent]: + s += f"{parent}\t{child}\t{left:0.3f}\t{right:0.3f}\n" + return s + + +def add_younger_edges_to_simplifier(simplifier, t, tables, edge_offset): + parent_edges = [] + while ( + edge_offset < len(tables.edges) + and tables.nodes.time[tables.edges.parent[edge_offset]] <= t + ): + print("edge offset = ", edge_offset) + if len(parent_edges) == 0: + last_parent = tables.edges.parent[edge_offset] + else: + last_parent = parent_edges[-1].parent + if last_parent == tables.edges.parent[edge_offset]: + parent_edges.append(tables.edges[edge_offset]) + else: + print( + "Flush ", tables.nodes.time[parent_edges[-1].parent], len(parent_edges) + ) + simplifier.process_parent_edges(parent_edges) + parent_edges = [] + edge_offset += 1 + if len(parent_edges) > 0: + print("Flush ", tables.nodes.time[parent_edges[-1].parent], len(parent_edges)) + simplifier.process_parent_edges(parent_edges) + return edge_offset + + +def simplify_with_births(tables, births, alive, verbose): + total_edges = len(tables.edges) + for edges in births.edges.values(): + total_edges += len(edges) + if verbose > 0: + print("Simplify with births") + # print(births) + print("total_input edges = ", total_edges) + print("alive = ", alive) + print("\ttable edges:", len(tables.edges)) + print("\ttable nodes:", len(tables.nodes)) + + simplifier = simplify.Simplifier(tables.tree_sequence(), alive) + nodes_time = tables.nodes.time + # This should be almost sorted, 
because + parent_time = nodes_time[births.parents] + index = np.argsort(parent_time) + print(index) + offset = 0 + for parent in np.array(births.parents)[index]: + offset = add_younger_edges_to_simplifier( + simplifier, nodes_time[parent], tables, offset + ) + edges = [ + tskit.Edge(left, right, parent, child) + for child, left, right in sorted(births.edges[parent]) + ] + # print("Adding parent from time", nodes_time[parent], len(edges)) + # print("edges = ", edges) + simplifier.process_parent_edges(edges) + # simplifier.print_state() + + # FIXME should probably reuse the add_younger_edges_to_simplifier function + # for this - doesn't quite seem to work though + for _, edges in itertools.groupby(tables.edges[offset:], lambda e: e.parent): + edges = list(edges) + simplifier.process_parent_edges(edges) + + simplifier.check_state() + assert simplifier.parent_edges_processed == total_edges + # if simplifier.parent_edges_processed != total_edges: + # print("HERE!!!!", total_edges) + simplifier.finalise() + + tables.nodes.replace_with(simplifier.tables.nodes) + tables.edges.replace_with(simplifier.tables.edges) + + # This is needed because we call .tree_sequence here and later. + # Can be removed if we change the Simplifier to take a set of + # tables which it modifies, like the C version.
+ tables.drop_index() + # Just to check + tables.tree_sequence() + + births.clear() + # Add back all the edges with an alive parent to the buffer, so that + # we store them contiguously + keep = np.ones(len(tables.edges), dtype=bool) + for u in alive: + u = simplifier.node_id_map[u] + for e in np.where(tables.edges.parent == u)[0]: + keep[e] = False + edge = tables.edges[e] + # print(edge) + births.add_edge(edge.left, edge.right, edge.parent, edge.child) + + if verbose > 0: + print("Done") + print(births) + print("\ttable edges:", len(tables.edges)) + print("\ttable nodes:", len(tables.nodes)) + + +def simplify_with_births_easy(tables, births, alive, verbose): + for parent, edges in births.edges.items(): + for child, left, right in edges: + tables.edges.add_row(left, right, parent, child) + tables.sort() + tables.simplify(alive) + births.clear() + + # print(tables.nodes.time[tables.edges.parent]) + + +def wright_fisher( + N, *, death_proba=1, L=1, T=10, simplify_interval=1, seed=42, verbose=0 +): + rng = random.Random(seed) + tables = tskit.TableCollection(L) + alive = [tables.nodes.add_row(time=T) for _ in range(N)] + births = BirthBuffer() + + t = T + while t > 0: + t -= 1 + next_alive = list(alive) + for j in range(N): + if rng.random() < death_proba: + # alive[j] is dead - replace it. 
+ u = tables.nodes.add_row(time=t) + next_alive[j] = u + a = rng.randint(0, N - 1) + b = rng.randint(0, N - 1) + x = rng.uniform(0, L) + # TODO Possibly more natural to do this like + # births.add(u, parents=[a, b], breaks=[0, x, L]) + births.add_edge(0, x, alive[a], u) + births.add_edge(x, L, alive[b], u) + alive = next_alive + if t % simplify_interval == 0 or t == 0: + simplify_with_births(tables, births, alive, verbose=verbose) + # simplify_with_births_easy(tables, births, alive, verbose=verbose) + alive = list(range(N)) + # print(tables.tree_sequence()) + return tables.tree_sequence() + + +class TestSimulationBasics: + """ + Check that the basic simulation algorithm roughly works, so we're not building + on sand. + """ + + @pytest.mark.parametrize("N", [1, 10, 100]) + def test_pop_size(self, N): + ts = wright_fisher(N, simplify_interval=100) + assert ts.num_samples == N + + @pytest.mark.parametrize("T", [1, 10, 100]) + def test_time(self, T): + N = 10 + ts = wright_fisher(N=N, T=T, simplify_interval=1000) + assert np.all(ts.nodes_time[ts.samples()] == 0) + # Can't really assert anything much stronger, not really trying to + # do anything particularly rigorous here + assert np.max(ts.nodes_time) > 0 + + def test_death_proba_0(self): + N = 10 + T = 5 + ts = wright_fisher(N=N, T=T, death_proba=0, simplify_interval=1000) + assert ts.num_nodes == N + + @pytest.mark.parametrize("seed", [1, 5, 1234]) + def test_seed_identical(self, seed): + N = 10 + T = 5 + ts1 = wright_fisher(N=N, T=T, simplify_interval=1000, seed=seed) + ts2 = wright_fisher(N=N, T=T, simplify_interval=1000, seed=seed) + ts1.tables.assert_equals(ts2.tables, ignore_provenance=True) + ts3 = wright_fisher(N=N, T=T, simplify_interval=1000, seed=seed - 1) + assert not ts3.tables.equals(ts2.tables, ignore_provenance=True) + + def test_full_simulation(self): + ts = wright_fisher(N=5, T=500, death_proba=0.9, simplify_interval=1000) + for tree in ts.trees(): + assert tree.num_roots == 1 + + +class
TestSimplifyIntervals: + @pytest.mark.parametrize("interval", [1, 10, 33, 100]) + def test_non_overlapping_generations(self, interval): + N = 10 + ts = wright_fisher(N, T=100, death_proba=1, simplify_interval=interval) + assert ts.num_samples == N + + @pytest.mark.parametrize("interval", [1, 10, 33, 100]) + @pytest.mark.parametrize("death_proba", [0.33, 0.5, 0.9]) + def test_overlapping_generations(self, interval, death_proba): + N = 4 + ts = wright_fisher( + N, T=20, death_proba=death_proba, simplify_interval=interval, verbose=1 + ) + assert ts.num_samples == N + print() + print(ts.draw_text())