diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c index 35991288d4..2d5bc97fd5 100644 --- a/c/tests/test_stats.c +++ b/c/tests/test_stats.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -262,6 +262,48 @@ verify_mean_descendants(tsk_treeseq_t *ts) free(C); } +/* Check the divergence matrix by running against the stats API equivalent + * code. NOTE: this will not always be equal in site mode, because of a slightly + * different definition wrt to multiple mutations at a site. + */ +static void +verify_divergence_matrix(tsk_treeseq_t *ts, tsk_flags_t mode) +{ + int ret; + const tsk_size_t n = tsk_treeseq_get_num_samples(ts); + const tsk_id_t *samples = tsk_treeseq_get_samples(ts); + tsk_size_t sample_set_sizes[n]; + tsk_id_t index_tuples[2 * n * n]; + double D1[n * n], D2[n * n]; + tsk_size_t i, j, k; + + for (j = 0; j < n; j++) { + sample_set_sizes[j] = 1; + for (k = 0; k < n; k++) { + index_tuples[2 * (j * n + k)] = (tsk_id_t) j; + index_tuples[2 * (j * n + k) + 1] = (tsk_id_t) k; + } + } + ret = tsk_treeseq_divergence( + ts, n, sample_set_sizes, samples, n * n, index_tuples, 0, NULL, mode, D1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_treeseq_divergence_matrix(ts, 0, NULL, 0, NULL, mode, D2); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + for (j = 0; j < n; j++) { + for (k = 0; k < n; k++) { + i = j * n + k; + /* printf("%d\t%d\t%f\t%f\n", (int) j, (int) k, D1[i], D2[i]); */ + if (j == k) { + CU_ASSERT_EQUAL(D2[i], 0); + } else { + CU_ASSERT_DOUBLE_EQUAL(D1[i], D2[i], 1E-6); + } + } + } +} + typedef struct { int call_count; int error_on; @@ -973,6 +1015,128 @@ test_single_tree_general_stat_errors(void) tsk_treeseq_free(&ts); } +static void +test_single_tree_divergence_matrix(void) +{ + tsk_treeseq_t ts; + int ret; + double result[16]; + double D_branch[16] = { 0, 2, 6, 6, 2, 0, 6, 6, 6, 6, 0, 4, 6, 6, 4, 0 }; + double D_site[16] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; + + tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, NULL, + NULL, NULL, NULL, 0); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(16, result, D_branch); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(16, result, D_site); + + verify_divergence_matrix(&ts, TSK_STAT_BRANCH); + verify_divergence_matrix(&ts, TSK_STAT_SITE); + + tsk_treeseq_free(&ts); +} + +static void +test_single_tree_divergence_matrix_internal_samples(void) +{ + tsk_treeseq_t ts; + int ret; + double result[16]; + double D[16] = { 0, 2, 4, 3, 2, 0, 4, 3, 4, 4, 0, 1, 3, 3, 1, 0 }; + + const char *nodes = "1 0 -1 -1\n" /* 2.00┊ 6 ┊ */ + "1 0 -1 -1\n" /* ┊ ┏━┻━┓ ┊ */ + "1 0 -1 -1\n" /* 1.00┊ 4 5* ┊ */ + "0 0 -1 -1\n" /* ┊ ┏┻┓ ┏┻┓ ┊ */ + "0 1 -1 -1\n" /* 0.00┊ 0 1 2 3 ┊ */ + "1 1 -1 -1\n" /* 0 * * * 1 */ + "0 2 -1 -1\n"; + const char *edges = "0 1 4 0,1\n" + "0 1 5 2,3\n" + "0 1 6 4,5\n"; + /* One mutations per branch so we get the same as the branch length value */ + const char *sites = "0.1 A\n" + "0.2 A\n" + "0.3 A\n" + "0.4 A\n" + "0.5 A\n" + "0.6 A\n"; + const char *mutations = "0 0 T -1\n" + "1 1 T -1\n" + "2 2 T -1\n" + "3 3 T -1\n" + "4 4 T -1\n" + "5 5 T -1\n"; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, sites, mutations, NULL, NULL, 0); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(16, result, D); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(16, result, D); + + verify_divergence_matrix(&ts, TSK_STAT_BRANCH); + verify_divergence_matrix(&ts, TSK_STAT_SITE); + + tsk_treeseq_free(&ts); +} + +static void +test_single_tree_divergence_matrix_multi_root(void) +{ + tsk_treeseq_t ts; + int ret; + double result[16]; + double D_branch[16] = { 0, 2, 3, 3, 2, 0, 3, 3, 3, 3, 0, 4, 3, 3, 4, 0 }; + double D_site[16] = { 0, 4, 6, 6, 4, 0, 6, 6, 6, 6, 0, 8, 6, 6, 8, 0 }; + + const char *nodes = "1 0 -1 -1\n" + "1 0 -1 -1\n" /* 2.00┊ 5 ┊ */ + "1 0 -1 -1\n" /* 1.00┊ 4 ┊ */ + "1 0 -1 -1\n" /* ┊ ┏┻┓ ┏┻┓ ┊ */ + "0 1 -1 -1\n" /* 0.00┊ 0 1 2 3 ┊ */ + "0 2 -1 -1\n"; /* 0 * * * * 1 */ + const char *edges = "0 1 4 0,1\n" + "0 1 5 2,3\n"; + /* Two mutations per branch unit so we get twice branch length value */ + const char *sites = "0.1 A\n" + "0.2 A\n" + "0.3 A\n" + "0.4 A\n"; + const char *mutations = "0 0 B -1\n" + "0 0 C 0\n" + "1 1 B -1\n" + "1 1 C 2\n" + "2 2 B -1\n" + "2 2 C 4\n" + "2 2 D 5\n" + "2 2 E 6\n" + "3 3 B -1\n" + "3 3 C 8\n" + "3 3 D 9\n" + "3 3 E 10\n"; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, sites, mutations, NULL, NULL, 0); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(16, result, D_branch); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(16, result, D_site); + + tsk_treeseq_free(&ts); +} + static void test_paper_ex_ld(void) { @@ -1592,6 +1756,20 @@ test_paper_ex_afs(void) tsk_treeseq_free(&ts); } +static void +test_paper_ex_divergence_matrix(void) +{ + tsk_treeseq_t ts; + + tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, + paper_ex_mutations, paper_ex_individuals, NULL, 0); + + verify_divergence_matrix(&ts, TSK_STAT_BRANCH); + verify_divergence_matrix(&ts, TSK_STAT_SITE); + + tsk_treeseq_free(&ts); +} + static void test_nonbinary_ex_ld(void) { @@ -1726,6 +1904,158 @@ test_ld_silent_mutations(void) free(base_ts); } +static void +test_simplest_divergence_matrix(void) +{ + const char *nodes = "1 0 0\n" + "1 0 0\n" + "0 1 0\n"; + const char *edges = "0 1 2 0,1\n"; + tsk_treeseq_t ts; + tsk_id_t sample_ids[] = { 0, 1 }; + double D_branch[4] = { 0, 2, 2, 0 }; + double D_site[4] = { 0, 0, 0, 0 }; + double result[4]; + int ret; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0); + + ret = tsk_treeseq_divergence_matrix( + &ts, 2, sample_ids, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(4, D_branch, result); + + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(4, D_site, result); + + ret = tsk_treeseq_divergence_matrix( + &ts, 2, sample_ids, 0, NULL, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(4, D_site, result); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(4, D_branch, result); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(4, D_site, result); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_NODE, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSUPPORTED_STAT_MODE); + + ret = tsk_treeseq_divergence_matrix( + &ts, 0, NULL, 0, NULL, TSK_STAT_SPAN_NORMALISE, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED); + + ret = tsk_treeseq_divergence_matrix( + &ts, 0, NULL, 0, NULL, TSK_STAT_POLARISED, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_STAT_POLARISED_UNSUPPORTED); + + ret = tsk_treeseq_divergence_matrix( + &ts, 0, NULL, 0, NULL, TSK_STAT_SITE | TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MULTIPLE_STAT_MODES); + + sample_ids[0] = -1; + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + + sample_ids[0] = 3; + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + + tsk_treeseq_free(&ts); +} + +static void +test_simplest_divergence_matrix_windows(void) +{ + const char *nodes = "1 0 0\n" + "1 0 0\n" + "0 1 0\n"; + const char *edges = "0 1 2 0,1\n"; + tsk_treeseq_t ts; + tsk_id_t sample_ids[] = { 0, 1 }; + double D_branch[8] = { 0, 1, 1, 0, 0, 1, 1, 0 }; + double D_site[8] = { 0, 0, 0, 0, 0, 0, 0, 0 }; + double result[8]; + double windows[] = { 0, 0.5, 1 }; + int ret; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0); + + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 2, windows, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(8, D_site, result); + ret = tsk_treeseq_divergence_matrix( + &ts, 2, sample_ids, 2, windows, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(8, D_branch, result); + + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, windows, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_NUM_WINDOWS); + + windows[0] = -1; + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 2, windows, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); + + windows[0] = 0.45; + windows[2] = 1.5; + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 2, windows, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); + + windows[0] = 0.55; + windows[2] = 1.0; + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 2, windows, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); + + tsk_treeseq_free(&ts); +} + +static void +test_simplest_divergence_matrix_internal_sample(void) +{ + const char *nodes = "1 0 0\n" + "1 0 0\n" + "0 1 0\n"; + const char *edges = "0 1 2 0,1\n"; + tsk_treeseq_t ts; + tsk_id_t sample_ids[] = { 0, 1, 2 }; + double result[9]; + double D_branch[9] = { 0, 2, 1, 2, 0, 1, 1, 1, 0 }; + double D_site[9] = { 0, 0, 0, 0, 0, 0, 0, 0, 0 }; + int ret; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0); + + ret = tsk_treeseq_divergence_matrix( + &ts, 3, sample_ids, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(9, D_branch, result); + + ret = tsk_treeseq_divergence_matrix( + &ts, 3, sample_ids, 0, NULL, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(9, D_site, result); + + tsk_treeseq_free(&ts); +} + +static void +test_multiroot_divergence_matrix(void) +{ + tsk_treeseq_t ts; + + tsk_treeseq_from_text(&ts, 10, multiroot_ex_nodes, multiroot_ex_edges, NULL, + multiroot_ex_sites, multiroot_ex_mutations, NULL, NULL, 0); + + verify_divergence_matrix(&ts, TSK_STAT_BRANCH); + verify_divergence_matrix(&ts, TSK_STAT_SITE); + + tsk_treeseq_free(&ts); +} + int main(int argc, char **argv) { @@ -1745,6 +2075,11 @@ main(int argc, char **argv) test_single_tree_genealogical_nearest_neighbours }, { "test_single_tree_general_stat", test_single_tree_general_stat }, { "test_single_tree_general_stat_errors", test_single_tree_general_stat_errors }, + { "test_single_tree_divergence_matrix", test_single_tree_divergence_matrix }, + { "test_single_tree_divergence_matrix_internal_samples", + test_single_tree_divergence_matrix_internal_samples }, + { "test_single_tree_divergence_matrix_multi_root", + test_single_tree_divergence_matrix_multi_root }, { "test_paper_ex_ld", test_paper_ex_ld }, { "test_paper_ex_mean_descendants", test_paper_ex_mean_descendants }, @@ -1785,6 +2120,7 @@ main(int argc, char **argv) { "test_paper_ex_f4", test_paper_ex_f4 }, { "test_paper_ex_afs_errors", test_paper_ex_afs_errors }, { "test_paper_ex_afs", test_paper_ex_afs }, + { "test_paper_ex_divergence_matrix", test_paper_ex_divergence_matrix }, { "test_nonbinary_ex_ld", test_nonbinary_ex_ld }, { "test_nonbinary_ex_mean_descendants", test_nonbinary_ex_mean_descendants }, @@ -1798,6 +2134,13 @@ main(int argc, char **argv) { "test_ld_multi_mutations", test_ld_multi_mutations }, { "test_ld_silent_mutations", test_ld_silent_mutations }, + { "test_simplest_divergence_matrix", test_simplest_divergence_matrix }, + { "test_simplest_divergence_matrix_windows", + test_simplest_divergence_matrix_windows }, + { "test_simplest_divergence_matrix_internal_sample", + test_simplest_divergence_matrix_internal_sample }, + { "test_multiroot_divergence_matrix", test_multiroot_divergence_matrix }, + { NULL, NULL }, }; return test_main(tests, argc, argv); diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c index 94e33ee487..cceb11d6fd 100644 --- a/c/tests/test_trees.c +++ b/c/tests/test_trees.c @@ -5395,7 +5395,6 @@ test_simplify_keep_input_roots_multi_tree(void) tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, paper_ex_mutations, paper_ex_individuals, NULL, 0); - tsk_treeseq_dump(&ts, "tmp.trees", 0); ret = tsk_treeseq_simplify( &ts, samples, 2, TSK_SIMPLIFY_KEEP_INPUT_ROOTS, &simplified, NULL); CU_ASSERT_EQUAL_FATAL(ret, 0); @@ -7801,7 +7800,7 @@ test_time_uncalibrated(void) tsk_size_t sample_set_sizes[] = { 2, 2 }; tsk_id_t samples[] = { 0, 1, 2, 3 }; tsk_size_t num_samples; - double result[10]; + double result[100]; double *W; double *sigma; @@ -7857,6 +7856,12 @@ test_time_uncalibrated(void) TSK_STAT_BRANCH | TSK_STAT_ALLOW_TIME_UNCALIBRATED, sigma); CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_divergence_matrix(&ts2, 0, NULL, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_TIME_UNCALIBRATED); + ret = tsk_treeseq_divergence_matrix(&ts2, 0, NULL, 0, NULL, + TSK_STAT_BRANCH | TSK_STAT_ALLOW_TIME_UNCALIBRATED, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_safe_free(W); tsk_safe_free(sigma); tsk_treeseq_free(&ts); diff --git a/c/tests/testlib.c b/c/tests/testlib.c index 823068d136..043ae5ceab 100644 --- a/c/tests/testlib.c +++ b/c/tests/testlib.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -966,6 +966,16 @@ tskit_suite_init(void) return CUE_SUCCESS; } +void +assert_arrays_almost_equal(tsk_size_t len, double *a, double *b) +{ + tsk_size_t j; + + for (j = 0; j < len; j++) { + CU_ASSERT_DOUBLE_EQUAL(a[j], b[j], 1e-9); + } +} + static int tskit_suite_cleanup(void) { diff --git a/c/tests/testlib.h b/c/tests/testlib.h index d042d60b55..69efb14781 100644 --- a/c/tests/testlib.h +++ b/c/tests/testlib.h @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2021 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -54,6 +54,8 @@ void parse_individuals(const char *text, tsk_individual_table_t *individual_tabl void unsort_edges(tsk_edge_table_t *edges, size_t start); +void assert_arrays_almost_equal(tsk_size_t len, double *a, double *b); + extern const char *single_tree_ex_nodes; extern const char *single_tree_ex_edges; extern const char *single_tree_ex_sites; diff --git a/c/tskit/core.c b/c/tskit/core.c index b1ea25badd..100cc78cad 100644 --- a/c/tskit/core.c +++ b/c/tskit/core.c @@ -466,6 +466,15 @@ tsk_strerror_internal(int err) ret = "Statistics using branch lengths cannot be calculated when time_units " "is 'uncalibrated'. (TSK_ERR_TIME_UNCALIBRATED)"; break; + case TSK_ERR_STAT_POLARISED_UNSUPPORTED: + ret = "The TSK_STAT_POLARISED option is not supported by this statistic. " + "(TSK_ERR_STAT_POLARISED_UNSUPPORTED)"; + break; + case TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED: + ret = "The TSK_STAT_SPAN_NORMALISE option is not supported by this " + "statistic. " + "(TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED)"; + break; /* Mutation mapping errors */ case TSK_ERR_GENOTYPES_ALL_MISSING: diff --git a/c/tskit/core.h b/c/tskit/core.h index b8b9f354ba..4d2c95212d 100644 --- a/c/tskit/core.h +++ b/c/tskit/core.h @@ -675,6 +675,16 @@ Statistics based on branch lengths were attempted when the ``time_units`` were ``uncalibrated``. */ #define TSK_ERR_TIME_UNCALIBRATED -910 +/** +The TSK_STAT_POLARISED option was passed to a statistic that does +not support it. +*/ +#define TSK_ERR_STAT_POLARISED_UNSUPPORTED -911 +/** +The TSK_STAT_SPAN_NORMALISE option was passed to a statistic that does +not support it. +*/ +#define TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED -912 /** @} */ /** diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 4604579e0b..cd0ad36aa2 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -1191,9 +1191,11 @@ tsk_treeseq_mean_descendants(const tsk_treeseq_t *self, * General stats framework ***********************************/ +#define TSK_REQUIRE_FULL_SPAN 1 + static int -tsk_treeseq_check_windows( - const tsk_treeseq_t *self, tsk_size_t num_windows, const double *windows) +tsk_treeseq_check_windows(const tsk_treeseq_t *self, tsk_size_t num_windows, + const double *windows, tsk_flags_t options) { int ret = TSK_ERR_BAD_WINDOWS; tsk_size_t j; @@ -1202,12 +1204,23 @@ tsk_treeseq_check_windows( ret = TSK_ERR_BAD_NUM_WINDOWS; goto out; } - /* TODO these restrictions can be lifted later if we want a specific interval. */ - if (windows[0] != 0) { - goto out; - } - if (windows[num_windows] != self->tables->sequence_length) { - goto out; + if (options & TSK_REQUIRE_FULL_SPAN) { + /* TODO the general stat code currently requires that we include the + * entire tree sequence span. This should be relaxed, so hopefully + * this branch (and the option) can be removed at some point */ + if (windows[0] != 0) { + goto out; + } + if (windows[num_windows] != self->tables->sequence_length) { + goto out; + } + } else { + if (windows[0] < 0) { + goto out; + } + if (windows[num_windows] > self->tables->sequence_length) { + goto out; + } } for (j = 0; j < num_windows; j++) { if (windows[j] >= windows[j + 1]) { @@ -1960,7 +1973,8 @@ tsk_treeseq_general_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, num_windows = 1; windows = default_windows; } else { - ret = tsk_treeseq_check_windows(self, num_windows, windows); + ret = tsk_treeseq_check_windows( + self, num_windows, windows, TSK_REQUIRE_FULL_SPAN); if (ret != 0) { goto out; } @@ -2468,7 +2482,7 @@ tsk_treeseq_allele_frequency_spectrum(const tsk_treeseq_t *self, bool stat_site = !!(options & TSK_STAT_SITE); bool stat_branch = !!(options & TSK_STAT_BRANCH); bool stat_node = !!(options & TSK_STAT_NODE); - double default_windows[] = { 0, self->tables->sequence_length }; + const double default_windows[] = { 0, self->tables->sequence_length }; const tsk_size_t num_nodes = self->tables->nodes.num_rows; const tsk_size_t K = num_sample_sets + 1; tsk_size_t j, k, l, afs_size; @@ -2496,7 +2510,8 @@ tsk_treeseq_allele_frequency_spectrum(const tsk_treeseq_t *self, num_windows = 1; windows = default_windows; } else { - ret = tsk_treeseq_check_windows(self, num_windows, windows); + ret = tsk_treeseq_check_windows( + self, num_windows, windows, TSK_REQUIRE_FULL_SPAN); if (ret != 0) { goto out; } @@ -3331,7 +3346,7 @@ tsk_treeseq_simplify(const tsk_treeseq_t *self, const tsk_id_t *samples, } ret = tsk_treeseq_init( output, tables, TSK_TS_INIT_BUILD_INDEXES | TSK_TAKE_OWNERSHIP); - /* Once tsk_tree_init has returned ownership of tables is transferred */ + /* Once tsk_treeseq_init has returned ownership of tables is transferred */ tables = NULL; out: if (tables != NULL) { @@ -3460,6 +3475,20 @@ tsk_treeseq_split_edges(const tsk_treeseq_t *self, double time, tsk_flags_t flag * Tree * ======================================================== */ +/* Return the root for the specified node. + * NOTE: no bounds checking is done here. + */ +static tsk_id_t +tsk_tree_get_node_root(const tsk_tree_t *self, tsk_id_t u) +{ + const tsk_id_t *restrict parent = self->parent; + + while (parent[u] != TSK_NULL) { + u = parent[u]; + } + return u; +} + int TSK_WARN_UNUSED tsk_tree_init(tsk_tree_t *self, const tsk_treeseq_t *tree_sequence, tsk_flags_t options) { @@ -6009,3 +6038,526 @@ tsk_treeseq_kc_distance(const tsk_treeseq_t *self, const tsk_treeseq_t *other, } return ret; } + +/* + * Divergence matrix + */ + +typedef struct { + /* Note it's a waste storing the triply linked tree here, but the code + * is written on the assumption of 1-based trees and the algorithm is + * frighteningly subtle, so it doesn't seem worth messing with it + * unless we really need to save some memory */ + tsk_id_t *parent; + tsk_id_t *child; + tsk_id_t *sib; + tsk_id_t *lambda; + tsk_id_t *pi; + tsk_id_t *tau; + tsk_id_t *beta; + tsk_id_t *alpha; +} sv_tables_t; + +static int +sv_tables_init(sv_tables_t *self, tsk_size_t n) +{ + int ret = 0; + + self->parent = tsk_malloc(n * sizeof(*self->parent)); + self->child = tsk_malloc(n * sizeof(*self->child)); + self->sib = tsk_malloc(n * sizeof(*self->sib)); + self->pi = tsk_malloc(n * sizeof(*self->pi)); + self->lambda = tsk_malloc(n * sizeof(*self->lambda)); + self->tau = tsk_malloc(n * sizeof(*self->tau)); + self->beta = tsk_malloc(n * sizeof(*self->beta)); + self->alpha = tsk_malloc(n * sizeof(*self->alpha)); + if (self->parent == NULL || self->child == NULL || self->sib == NULL + || self->lambda == NULL || self->tau == NULL || self->beta == NULL + || self->alpha == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } +out: + return ret; +} + +static int +sv_tables_free(sv_tables_t *self) +{ + tsk_safe_free(self->parent); + tsk_safe_free(self->child); + tsk_safe_free(self->sib); + tsk_safe_free(self->lambda); + tsk_safe_free(self->pi); + tsk_safe_free(self->tau); + tsk_safe_free(self->beta); + tsk_safe_free(self->alpha); + return 0; +} +static void +sv_tables_reset(sv_tables_t *self, tsk_tree_t *tree) +{ + const tsk_size_t n = 1 + tree->num_nodes; + tsk_memset(self->parent, 0, n * sizeof(*self->parent)); + tsk_memset(self->child, 0, n * sizeof(*self->child)); + tsk_memset(self->sib, 0, n * sizeof(*self->sib)); + tsk_memset(self->pi, 0, n * sizeof(*self->pi)); + tsk_memset(self->lambda, 0, n * sizeof(*self->lambda)); + tsk_memset(self->tau, 0, n * sizeof(*self->tau)); + tsk_memset(self->beta, 0, n * sizeof(*self->beta)); + tsk_memset(self->alpha, 0, n * sizeof(*self->alpha)); +} + +static void +sv_tables_convert_tree(sv_tables_t *self, tsk_tree_t *tree) +{ + const tsk_size_t n = 1 + tree->num_nodes; + const tsk_id_t *restrict tsk_parent = tree->parent; + tsk_id_t *restrict child = self->child; + tsk_id_t *restrict parent = self->parent; + tsk_id_t *restrict sib = self->sib; + tsk_size_t j; + tsk_id_t u, v; + + for (j = 0; j < n - 1; j++) { + u = (tsk_id_t) j + 1; + v = tsk_parent[j] + 1; + sib[u] = child[v]; + child[v] = u; + parent[u] = v; + } +} + +#define LAMBDA 0 + +static void +sv_tables_build_index(sv_tables_t *self) +{ + const tsk_id_t *restrict child = self->child; + const tsk_id_t *restrict parent = self->parent; + const tsk_id_t *restrict sib = self->sib; + tsk_id_t *restrict lambda = self->lambda; + tsk_id_t *restrict pi = self->pi; + tsk_id_t *restrict tau = self->tau; + tsk_id_t *restrict beta = self->beta; + tsk_id_t *restrict alpha = self->alpha; + tsk_id_t a, n, p, h; + + p = child[LAMBDA]; + n = 0; + lambda[0] = -1; + while (p != LAMBDA) { + while (true) { + n++; + pi[p] = n; + tau[n] = LAMBDA; + lambda[n] = 1 + lambda[n >> 1]; + if (child[p] != LAMBDA) { + p = child[p]; + } else { + break; + } + } + beta[p] = n; + while (true) { + tau[beta[p]] = parent[p]; + if (sib[p] != LAMBDA) { + p = sib[p]; + break; + } else { + p = parent[p]; + if (p != LAMBDA) { + h = lambda[n & -pi[p]]; + beta[p] = ((n >> h) | 1) << h; + } else { + break; + } + } + } + } + + /* Begin the second traversal */ + lambda[0] = lambda[n]; + pi[LAMBDA] = 0; + beta[LAMBDA] = 0; + alpha[LAMBDA] = 0; + p = child[LAMBDA]; + while (p != LAMBDA) { + while (true) { + a = alpha[parent[p]] | (beta[p] & -beta[p]); + alpha[p] = a; + if (child[p] != LAMBDA) { + p = child[p]; + } else { + break; + } + } + while (true) { + if (sib[p] != LAMBDA) { + p = sib[p]; + break; + } else { + p = parent[p]; + if (p == LAMBDA) { + break; + } + } + } + } +} + +static void +sv_tables_build(sv_tables_t *self, tsk_tree_t *tree) +{ + sv_tables_reset(self, tree); + sv_tables_convert_tree(self, tree); + sv_tables_build_index(self); +} + +static tsk_id_t +sv_tables_mrca_one_based(const sv_tables_t *self, tsk_id_t x, tsk_id_t y) +{ + const tsk_id_t *restrict lambda = self->lambda; + const tsk_id_t *restrict pi = self->pi; + const tsk_id_t *restrict tau = self->tau; + const tsk_id_t *restrict beta = self->beta; + const tsk_id_t *restrict alpha = self->alpha; + tsk_id_t h, k, xhat, yhat, ell, j, z; + + if (beta[x] <= beta[y]) { + h = lambda[beta[y] & -beta[x]]; + } else { + h = lambda[beta[x] & -beta[y]]; + } + k = alpha[x] & alpha[y] & -(1 << h); + h = lambda[k & -k]; + j = ((beta[x] >> h) | 1) << h; + if (j == beta[x]) { + xhat = x; + } else { + ell = lambda[alpha[x] & ((1 << h) - 1)]; + xhat = tau[((beta[x] >> ell) | 1) << ell]; + } + if (j == beta[y]) { + yhat = y; + } else { + ell = lambda[alpha[y] & ((1 << h) - 1)]; + yhat = tau[((beta[y] >> ell) | 1) << ell]; + } + if (pi[xhat] <= pi[yhat]) { + z = xhat; + } else { + z = yhat; + } + return z; +} + +static tsk_id_t +sv_tables_mrca(const sv_tables_t *self, tsk_id_t x, tsk_id_t y) +{ + /* Convert to 1-based indexes and back */ + return sv_tables_mrca_one_based(self, x + 1, y + 1) - 1; +} + +static int +tsk_treeseq_check_node_bounds( + const tsk_treeseq_t *self, tsk_size_t num_nodes, const tsk_id_t *nodes) +{ + int ret = 0; + tsk_size_t j; + tsk_id_t u; + const tsk_id_t N = (tsk_id_t) self->tables->nodes.num_rows; + + for (j = 0; j < num_nodes; j++) { + u = nodes[j]; + if (u < 0 || u >= N) { + ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + goto out; + } + } +out: + return ret; +} + +static int +tsk_treeseq_divergence_matrix_branch(const tsk_treeseq_t *self, tsk_size_t num_samples, + const tsk_id_t *restrict samples, tsk_size_t num_windows, + const double *restrict windows, tsk_flags_t options, double *restrict result) +{ + int ret = 0; + tsk_tree_t tree; + const double *restrict nodes_time = self->tables->nodes.time; + const tsk_size_t n = num_samples; + tsk_size_t i, j, k; + tsk_id_t u, v, w, u_root, v_root; + double tu, tv, d, span, left, right, span_left, span_right; + double *restrict D; + sv_tables_t sv; + + memset(&sv, 0, sizeof(sv)); + ret = tsk_tree_init(&tree, self, 0); + if (ret != 0) { + goto out; + } + ret = sv_tables_init(&sv, self->tables->nodes.num_rows + 1); + if (ret != 0) { + goto out; + } + + if (self->time_uncalibrated && !(options & TSK_STAT_ALLOW_TIME_UNCALIBRATED)) { + ret = TSK_ERR_TIME_UNCALIBRATED; + goto out; + } + + for (i = 0; i < num_windows; i++) { + left = windows[i]; + right = windows[i + 1]; + D = result + i * n * n; + ret = tsk_tree_seek(&tree, left, 0); + if (ret != 0) { + goto out; + } + while (tree.interval.left < right && tree.index != -1) { + span_left = TSK_MAX(tree.interval.left, left); + span_right = TSK_MIN(tree.interval.right, right); + span = span_right - span_left; + sv_tables_build(&sv, &tree); + for (j = 0; j < n; j++) { + u = samples[j]; + for (k = j + 1; k < n; k++) { + v = samples[k]; + w = sv_tables_mrca(&sv, u, v); + if (w != TSK_NULL) { + u_root = w; + v_root = w; + } else { + /* Slow path - only happens for nodes in disconnected + * subtrees in a tree with multiple roots */ + u_root = tsk_tree_get_node_root(&tree, u); + v_root = tsk_tree_get_node_root(&tree, v); + } + tu = nodes_time[u_root] - nodes_time[u]; + tv = nodes_time[v_root] - nodes_time[v]; + d = (tu + tv) * span; + D[j * n + k] += d; + } + } + ret = tsk_tree_next(&tree); + if (ret < 0) { + goto out; + } + } + } + ret = 0; +out: + tsk_tree_free(&tree); + sv_tables_free(&sv); + return ret; +} + +static tsk_size_t +count_mutations_on_path(tsk_id_t u, tsk_id_t v, const tsk_id_t *restrict parent, + const double *restrict time, const tsk_size_t *restrict mutations_per_node) +{ + double tu, tv; + tsk_size_t count = 0; + + tu = time[u]; + tv = time[v]; + while (u != v) { + if (tu < tv) { + count += mutations_per_node[u]; + u = parent[u]; + if (u == TSK_NULL) { + break; + } + tu = time[u]; + } else { + count += mutations_per_node[v]; + v = parent[v]; + if (v == TSK_NULL) { + break; + } + tv = time[v]; + } + } + if (u != v) { + while (u != TSK_NULL) { + count += mutations_per_node[u]; + u = parent[u]; + } + while (v != TSK_NULL) { + count += mutations_per_node[v]; + v = parent[v]; + } + } + return count; +} + +static int +tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_samples, + const tsk_id_t *restrict samples, tsk_size_t num_windows, + const double *restrict windows, tsk_flags_t TSK_UNUSED(options), + double *restrict result) +{ + int ret = 0; + tsk_tree_t tree; + const tsk_size_t n = num_samples; + const tsk_size_t num_nodes = self->tables->nodes.num_rows; + const double *restrict nodes_time = self->tables->nodes.time; + tsk_size_t i, j, k, tree_site, tree_mut; + tsk_site_t site; + tsk_mutation_t mut; + tsk_id_t u, v; + double left, right, span_left, span_right; + double *restrict D; + tsk_size_t *mutations_per_node = tsk_malloc(num_nodes * sizeof(*mutations_per_node)); + + ret = tsk_tree_init(&tree, self, 0); + if (ret != 0) { + goto out; + } + if (mutations_per_node == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + + for (i = 0; i < num_windows; i++) { + left = windows[i]; + right = windows[i + 1]; + D = result + i * n * n; + ret = tsk_tree_seek(&tree, left, 0); + if (ret != 0) { + goto out; + } + while (tree.interval.left < right && tree.index != -1) { + span_left = TSK_MAX(tree.interval.left, left); + span_right = TSK_MIN(tree.interval.right, right); + + /* NOTE: we could avoid this full memset across all nodes by doing + * the same loops again and decrementing at the end of the main + * tree-loop. It's probably not worth it though, because of the + * overwhelming O(n^2) below */ + tsk_memset(mutations_per_node, 0, num_nodes * sizeof(*mutations_per_node)); + for (tree_site = 0; tree_site < tree.sites_length; tree_site++) { + site = tree.sites[tree_site]; + if (span_left <= site.position && site.position < span_right) { + for (tree_mut = 0; tree_mut < site.mutations_length; tree_mut++) { + mut = site.mutations[tree_mut]; + mutations_per_node[mut.node]++; + } + } + } + + for (j = 0; j < n; j++) { + u = samples[j]; + for (k = j + 1; k < n; k++) { + v = samples[k]; + D[j * n + k] += (double) count_mutations_on_path( + u, v, tree.parent, nodes_time, mutations_per_node); + } + } + ret = tsk_tree_next(&tree); + if (ret < 0) { + goto out; + } + } + } + ret = 0; +out: + tsk_tree_free(&tree); + tsk_safe_free(mutations_per_node); + return ret; +} + +static void +fill_lower_triangle( + double *restrict result, const tsk_size_t n, const tsk_size_t num_windows) +{ + tsk_size_t i, j, k; + double *restrict D; + + /* TODO there's probably a better striding pattern that could be used here */ + for (i = 0; i < num_windows; i++) { + D = result + i * n * n; + for (j = 0; j < n; j++) { + for (k = j + 1; k < n; k++) { + D[k * n + j] = D[j * n + k]; + } + } + } +} + +int +tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples, + const tsk_id_t *samples_in, tsk_size_t num_windows, const double *windows, + tsk_flags_t options, double *result) +{ + int ret = 0; + const tsk_id_t *samples = self->samples; + tsk_size_t n = self->num_samples; + const double default_windows[] = { 0, self->tables->sequence_length }; + bool stat_site = !!(options & TSK_STAT_SITE); + bool stat_branch = !!(options & TSK_STAT_BRANCH); + bool stat_node = !!(options & TSK_STAT_NODE); + + if (stat_node) { + ret = TSK_ERR_UNSUPPORTED_STAT_MODE; + goto out; + } + /* If no mode is specified, we default to site mode */ + if (!(stat_site || stat_branch)) { + stat_site = true; + } + /* It's an error to specify more than one mode */ + if (stat_site + stat_branch > 1) { + ret = TSK_ERR_MULTIPLE_STAT_MODES; + goto out; + } + + if (options & TSK_STAT_POLARISED) { + ret = TSK_ERR_STAT_POLARISED_UNSUPPORTED; + goto out; + } + if (options & TSK_STAT_SPAN_NORMALISE) { + ret = TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED; + goto out; + } + + if (windows == NULL) { + num_windows = 1; + windows = default_windows; + } else { + ret = tsk_treeseq_check_windows(self, num_windows, windows, 0); + if (ret != 0) { + goto out; + } + } + + if (samples_in != NULL) { + samples = samples_in; + n = num_samples; + ret = tsk_treeseq_check_node_bounds(self, n, samples); + if (ret != 0) { + goto out; + } + } + + tsk_memset(result, 0, num_windows * n * n * sizeof(*result)); + + if (stat_branch) { + ret = tsk_treeseq_divergence_matrix_branch( + self, n, samples, num_windows, windows, options, result); + } else { + tsk_bug_assert(stat_site); + ret = tsk_treeseq_divergence_matrix_site( + self, n, samples, num_windows, windows, options, result); + } + if (ret != 0) { + goto out; + } + fill_lower_triangle(result, n, num_windows); + +out: + return ret; +} diff --git a/c/tskit/trees.h b/c/tskit/trees.h index efe9980077..10c820e1c0 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1003,6 +1003,10 @@ int tsk_treeseq_f4(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); +int tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples, + const tsk_id_t *samples, tsk_size_t num_windows, const double *windows, + tsk_flags_t options, double *result); + /****************************************************************************/ /* Tree */ /****************************************************************************/ diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index dea3c03fd9..8d42b50afc 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -9638,6 +9638,78 @@ TreeSequence_f4(TreeSequence *self, PyObject *args, PyObject *kwds) return TreeSequence_k_way_stat_method(self, args, kwds, 4, tsk_treeseq_f4); } +static PyObject * +TreeSequence_divergence_matrix(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + PyObject *ret = NULL; + static char *kwlist[] = { "windows", "samples", "mode", NULL }; + PyArrayObject *result_array = NULL; + PyObject *windows = NULL; + PyObject *py_samples = Py_None; + char *mode = NULL; + PyArrayObject *windows_array = NULL; + PyArrayObject *samples_array = NULL; + tsk_flags_t options = 0; + npy_intp *shape, dims[3]; + tsk_size_t num_samples, num_windows; + tsk_id_t *samples = NULL; + int err; + + if (TreeSequence_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords( + args, kwds, "O|Os", kwlist, &windows, &py_samples, &mode)) { + goto out; + } + num_samples = tsk_treeseq_get_num_samples(self->tree_sequence); + if (py_samples != Py_None) { + samples_array = (PyArrayObject *) PyArray_FROMANY( + py_samples, NPY_INT32, 1, 1, NPY_ARRAY_IN_ARRAY); + if (samples_array == NULL) { + goto out; + } + shape = PyArray_DIMS(samples_array); + samples = PyArray_DATA(samples_array); + num_samples = (tsk_size_t) shape[0]; + } + if (parse_windows(windows, &windows_array, &num_windows) != 0) { + goto out; + } + dims[0] = num_windows; + dims[1] = num_samples; + dims[2] = num_samples; + result_array = (PyArrayObject *) PyArray_SimpleNew(3, dims, NPY_FLOAT64); + if (result_array == NULL) { + goto out; + } + if (parse_stats_mode(mode, &options) != 0) { + goto out; + } + // clang-format off + Py_BEGIN_ALLOW_THREADS + err = tsk_treeseq_divergence_matrix( + self->tree_sequence, + num_samples, samples, + num_windows, PyArray_DATA(windows_array), + options, PyArray_DATA(result_array)); + Py_END_ALLOW_THREADS + // clang-format on + /* Clang-format insists on doing this in spite of the "off" instruction above */ + if (err != 0) + { + handle_library_error(err); + goto out; + } + ret = (PyObject *) result_array; + result_array = NULL; +out: + Py_XDECREF(result_array); + Py_XDECREF(windows_array); + Py_XDECREF(samples_array); + return ret; +} + static PyObject * TreeSequence_get_num_mutations(TreeSequence *self) { @@ -10346,6 +10418,10 @@ static PyMethodDef TreeSequence_methods[] = { .ml_meth = (PyCFunction) TreeSequence_f4, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Computes the f4 statistic." }, + { .ml_name = "divergence_matrix", + .ml_meth = (PyCFunction) TreeSequence_divergence_matrix, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the pairwise divergence matrix." }, { .ml_name = "split_edges", .ml_meth = (PyCFunction) TreeSequence_split_edges, .ml_flags = METH_VARARGS | METH_KEYWORDS, diff --git a/python/tests/test_divmat.py b/python/tests/test_divmat.py new file mode 100644 index 0000000000..acb2403d41 --- /dev/null +++ b/python/tests/test_divmat.py @@ -0,0 +1,1064 @@ +# MIT License +# +# Copyright (c) 2023 Tskit Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Test cases for divergence matrix based pairwise stats +""" +import collections + +import msprime +import numpy as np +import pytest + +import tskit +from tests import tsutil +from tests.test_highlevel import get_example_tree_sequences + +# ↑ See https://github.com/tskit-dev/tskit/issues/1804 for when +# we can remove this. + +DIVMAT_MODES = ["branch", "site"] + +# NOTE: this implementation of Schieber-Vishkin algorithm is done like +# this so it's easy to run with numba. It would be more naturally +# packaged as a class. We don't actually use numba here, but it's +# handy to have a version of the SV code lying around that can be +# run directly with numba. + + +def sv_tables_init(parent_array): + n = 1 + parent_array.shape[0] + + LAMBDA = 0 + # Triply-linked tree. FIXME we shouldn't need to build this as it's + # available already in tskit + child = np.zeros(n, dtype=np.int32) + parent = np.zeros(n, dtype=np.int32) + sib = np.zeros(n, dtype=np.int32) + + for j in range(n - 1): + u = j + 1 + v = parent_array[j] + 1 + sib[u] = child[v] + child[v] = u + parent[u] = v + + lambd = np.zeros(n, dtype=np.int32) + pi = np.zeros(n, dtype=np.int32) + tau = np.zeros(n, dtype=np.int32) + beta = np.zeros(n, dtype=np.int32) + alpha = np.zeros(n, dtype=np.int32) + + p = child[LAMBDA] + n = 0 + lambd[0] = -1 + while p != LAMBDA: + while True: + n += 1 + pi[p] = n + tau[n] = LAMBDA + lambd[n] = 1 + lambd[n >> 1] + if child[p] != LAMBDA: + p = child[p] + else: + break + beta[p] = n + while True: + tau[beta[p]] = parent[p] + if sib[p] != LAMBDA: + p = sib[p] + break + else: + p = parent[p] + if p != LAMBDA: + h = lambd[n & -pi[p]] + beta[p] = ((n >> h) | 1) << h + else: + break + + # Begin the second traversal + lambd[0] = lambd[n] + pi[LAMBDA] = 0 + beta[LAMBDA] = 0 + alpha[LAMBDA] = 0 + p = child[LAMBDA] + while p != LAMBDA: + while True: + a = alpha[parent[p]] | (beta[p] & -beta[p]) + alpha[p] = a + if child[p] != LAMBDA: + p = child[p] + else: + break + while True: + if sib[p] != LAMBDA: + p = sib[p] + break + else: + p = parent[p] + if p == LAMBDA: + break + + return lambd, pi, tau, beta, alpha + + +def _sv_mrca(x, y, lambd, pi, tau, beta, alpha): + if beta[x] <= beta[y]: + h = lambd[beta[y] & -beta[x]] + else: + h = lambd[beta[x] & -beta[y]] + k = alpha[x] & alpha[y] & -(1 << h) + h = lambd[k & -k] + j = ((beta[x] >> h) | 1) << h + if j == beta[x]: + xhat = x + else: + ell = lambd[alpha[x] & ((1 << h) - 1)] + xhat = tau[((beta[x] >> ell) | 1) << ell] + if j == beta[y]: + yhat = y + else: + ell = lambd[alpha[y] & ((1 << h) - 1)] + yhat = tau[((beta[y] >> ell) | 1) << ell] + if pi[xhat] <= pi[yhat]: + z = xhat + else: + z = yhat + return z + + +def sv_mrca(x, y, lambd, pi, tau, beta, alpha): + # Convert to 1-based indexes + return _sv_mrca(x + 1, y + 1, lambd, pi, tau, beta, alpha) - 1 + + +def local_root(tree, u): + while tree.parent(u) != tskit.NULL: + u = tree.parent(u) + return u + + +def branch_divergence_matrix(ts, windows=None, samples=None): + windows_specified = windows is not None + windows = [0, ts.sequence_length] if windows is None else windows + num_windows = len(windows) - 1 + samples = ts.samples() if samples is None else samples + + n = len(samples) + D = np.zeros((num_windows, n, n)) + tree = tskit.Tree(ts) + for i in range(num_windows): + left = windows[i] + right = windows[i + 1] + # print(f"WINDOW {i} [{left}, {right})") + tree.seek(left) + # Iterate over the trees in this window + while tree.interval.left < right and tree.index != -1: + span_left = max(tree.interval.left, left) + span_right = min(tree.interval.right, right) + span = span_right - span_left + # print(f"\ttree {tree.interval} [{span_left}, {span_right})") + tables = sv_tables_init(tree.parent_array) + for j in range(n): + u = samples[j] + for k in range(j + 1, n): + v = samples[k] + w = sv_mrca(u, v, *tables) + assert w == tree.mrca(u, v) + if w != tskit.NULL: + tu = ts.nodes_time[w] - ts.nodes_time[u] + tv = ts.nodes_time[w] - ts.nodes_time[v] + else: + tu = ts.nodes_time[local_root(tree, u)] - ts.nodes_time[u] + tv = ts.nodes_time[local_root(tree, v)] - ts.nodes_time[v] + d = (tu + tv) * span + D[i, j, k] += d + tree.next() + # Fill out symmetric triangle in the matrix + for j in range(n): + for k in range(j + 1, n): + D[i, k, j] = D[i, j, k] + if not windows_specified: + D = D[0] + return D + + +def divergence_matrix(ts, windows=None, samples=None, mode="site"): + assert mode in ["site", "branch"] + if mode == "site": + return site_divergence_matrix(ts, samples=samples, windows=windows) + else: + return branch_divergence_matrix(ts, samples=samples, windows=windows) + + +def stats_api_divergence_matrix(ts, windows=None, samples=None, mode="site"): + samples = ts.samples() if samples is None else samples + windows_specified = windows is not None + windows = [0, ts.sequence_length] if windows is None else list(windows) + num_windows = len(windows) - 1 + + if len(samples) == 0: + # FIXME: the code general stat code doesn't seem to handle zero samples + # case, need to identify MWE and file issue. + if windows_specified: + return np.zeros(shape=(num_windows, 0, 0)) + else: + return np.zeros(shape=(0, 0)) + + # Make sure that all the specified samples have the sample flag set, otherwise + # the library code will complain + tables = ts.dump_tables() + flags = tables.nodes.flags + # NOTE: this is a shortcut, setting all flags unconditionally to zero, so don't + # use this tree sequence outside this method. + flags[:] = 0 + flags[samples] = tskit.NODE_IS_SAMPLE + tables.nodes.flags = flags + ts = tables.tree_sequence() + + # FIXME We have to go through this annoying rigmarole because windows must start and + # end with 0 and L. We should relax this requirement to just making the windows + # contiguous, so that we just look at specific sections of the genome. + drop = [] + if windows[0] != 0: + windows = [0] + windows + drop.append(0) + if windows[-1] != ts.sequence_length: + windows.append(ts.sequence_length) + drop.append(-1) + + n = len(samples) + sample_sets = [[u] for u in samples] + indexes = [(i, j) for i in range(n) for j in range(n)] + X = ts.divergence( + sample_sets, + indexes=indexes, + mode=mode, + span_normalise=False, + windows=windows, + ) + keep = np.ones(len(windows) - 1, dtype=bool) + keep[drop] = False + X = X[keep] + out = X.reshape((X.shape[0], n, n)) + for D in out: + np.fill_diagonal(D, 0) + if not windows_specified: + out = out[0] + return out + + +def rootward_path(tree, u, v): + while u != v: + yield u + u = tree.parent(u) + + +def site_divergence_matrix(ts, windows=None, samples=None): + windows_specified = windows is not None + windows = [0, ts.sequence_length] if windows is None else windows + num_windows = len(windows) - 1 + samples = ts.samples() if samples is None else samples + + n = len(samples) + D = np.zeros((num_windows, n, n)) + tree = tskit.Tree(ts) + for i in range(num_windows): + left = windows[i] + right = windows[i + 1] + tree.seek(left) + # Iterate over the trees in this window + while tree.interval.left < right and tree.index != -1: + span_left = max(tree.interval.left, left) + span_right = min(tree.interval.right, right) + mutations_per_node = collections.Counter() + for site in tree.sites(): + if span_left <= site.position < span_right: + for mutation in site.mutations: + mutations_per_node[mutation.node] += 1 + for j in range(n): + u = samples[j] + for k in range(j + 1, n): + v = samples[k] + w = tree.mrca(u, v) + if w != tskit.NULL: + wu = w + wv = w + else: + wu = local_root(tree, u) + wv = local_root(tree, v) + du = sum(mutations_per_node[x] for x in rootward_path(tree, u, wu)) + dv = sum(mutations_per_node[x] for x in rootward_path(tree, v, wv)) + # NOTE: we're just accumulating the raw mutation counts, not + # multiplying by span + D[i, j, k] += du + dv + tree.next() + # Fill out symmetric triangle in the matrix + for j in range(n): + for k in range(j + 1, n): + D[i, k, j] = D[i, j, k] + if not windows_specified: + D = D[0] + return D + + +def check_divmat( + ts, + *, + windows=None, + samples=None, + verbosity=0, + compare_stats_api=True, + compare_lib=True, + mode="site", +): + np.set_printoptions(linewidth=500, precision=4) + # print(ts.draw_text()) + if verbosity > 1: + print(ts.draw_text()) + + D1 = divergence_matrix(ts, windows=windows, samples=samples, mode=mode) + if compare_stats_api: + # Somethings like duplicate samples aren't worth hacking around for in + # stats API. + D2 = stats_api_divergence_matrix( + ts, windows=windows, samples=samples, mode=mode + ) + # print("windows = ", windows) + # print(D1) + # print(D2) + np.testing.assert_allclose(D1, D2) + assert D1.shape == D2.shape + if compare_lib: + D3 = ts.divergence_matrix(windows=windows, samples=samples, mode=mode) + # print(D3) + assert D1.shape == D3.shape + np.testing.assert_allclose(D1, D3) + return D1 + + +class TestExamplesWithAnswer: + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_zero_samples(self, mode): + ts = tskit.Tree.generate_balanced(2).tree_sequence + D = check_divmat(ts, samples=[], mode="site") + assert D.shape == (0, 0) + + @pytest.mark.parametrize("num_windows", [1, 2, 3, 5]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_zero_samples_windows(self, num_windows, mode): + ts = tskit.Tree.generate_balanced(2).tree_sequence + windows = np.linspace(0, ts.sequence_length, num=num_windows + 1) + D = check_divmat(ts, samples=[], windows=windows, mode="site") + assert D.shape == (num_windows, 0, 0) + + @pytest.mark.parametrize("m", [0, 1, 2, 10]) + def test_single_tree_sites_per_branch(self, m): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_sites(ts, m) + D1 = check_divmat(ts, mode="site") + D2 = np.array( + [ + [0.0, 2.0, 4.0, 4.0], + [2.0, 0.0, 4.0, 4.0], + [4.0, 4.0, 0.0, 2.0], + [4.0, 4.0, 2.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, m * D2) + + @pytest.mark.parametrize("m", [0, 1, 2, 10]) + def test_single_tree_mutations_per_branch(self, m): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_mutations(ts, m) + # The stats API will produce a different value here, because + # we're just counting up the mutations and not reasoning about + # the state of samples at all. + D1 = check_divmat(ts, mode="site", compare_stats_api=False) + D2 = np.array( + [ + [0.0, 2.0, 4.0, 4.0], + [2.0, 0.0, 4.0, 4.0], + [4.0, 4.0, 0.0, 2.0], + [4.0, 4.0, 2.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, m * D2) + + @pytest.mark.parametrize("L", [0.1, 1, 2, 100]) + def test_single_tree_sequence_length(self, L): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4, span=L).tree_sequence + D1 = check_divmat(ts, mode="branch") + D2 = np.array( + [ + [0.0, 2.0, 4.0, 4.0], + [2.0, 0.0, 4.0, 4.0], + [4.0, 4.0, 0.0, 2.0], + [4.0, 4.0, 2.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, L * D2) + + @pytest.mark.parametrize("num_windows", [1, 2, 3, 5]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_gap_at_end(self, num_windows, mode): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ 0 1 2 3 + # 0 1 2 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_sites(ts) + tables = ts.dump_tables() + tables.sequence_length = 2 + ts = tables.tree_sequence() + windows = np.linspace(0, ts.sequence_length, num=num_windows + 1) + D1 = check_divmat(ts, windows=windows, mode=mode) + D1 = np.sum(D1, axis=0) + D2 = np.array( + [ + [0.0, 2.0, 4.0, 4.0], + [2.0, 0.0, 4.0, 4.0], + [4.0, 4.0, 0.0, 2.0], + [4.0, 4.0, 2.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, D2) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_subset_permuted_samples(self, mode): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_sites(ts) + D1 = check_divmat(ts, samples=[1, 2, 0], mode=mode) + D2 = np.array( + [ + [0.0, 4.0, 2.0], + [4.0, 0.0, 4.0], + [2.0, 4.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, D2) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_mixed_non_sample_samples(self, mode): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_sites(ts) + D1 = check_divmat(ts, samples=[0, 5], mode=mode) + D2 = np.array( + [ + [0.0, 3.0], + [3.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, D2) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_duplicate_samples(self, mode): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_sites(ts) + D1 = check_divmat(ts, samples=[0, 0, 1], compare_stats_api=False, mode=mode) + D2 = np.array( + [ + [0.0, 0.0, 2.0], + [0.0, 0.0, 2.0], + [2.0, 2.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, D2) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_multiroot(self, mode): + # 2.00┊ ┊ + # ┊ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_sites(ts) + ts = ts.decapitate(1) + D1 = check_divmat(ts, mode=mode) + D2 = np.array( + [ + [0.0, 2.0, 2.0, 2.0], + [2.0, 0.0, 2.0, 2.0], + [2.0, 2.0, 0.0, 2.0], + [2.0, 2.0, 2.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, D2) + + @pytest.mark.parametrize( + ["left", "right"], [(0, 10), (1, 3), (3.25, 3.75), (5, 10)] + ) + def test_single_tree_interval(self, left, right): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4, span=10).tree_sequence + D1 = check_divmat(ts, windows=[left, right], mode="branch") + D2 = np.array( + [ + [0.0, 2.0, 4.0, 4.0], + [2.0, 0.0, 4.0, 4.0], + [4.0, 4.0, 0.0, 2.0], + [4.0, 4.0, 2.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1[0], (right - left) * D2) + + @pytest.mark.parametrize("num_windows", [1, 2, 3, 5, 11]) + def test_single_tree_equal_windows(self, num_windows): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4, span=10).tree_sequence + windows = np.linspace(0, ts.sequence_length, num=num_windows + 1) + x = ts.sequence_length / num_windows + # print(windows) + D1 = check_divmat(ts, windows=windows, mode="branch") + assert D1.shape == (num_windows, 4, 4) + D2 = np.array( + [ + [0.0, 2.0, 4.0, 4.0], + [2.0, 0.0, 4.0, 4.0], + [4.0, 4.0, 0.0, 2.0], + [4.0, 4.0, 2.0, 0.0], + ] + ) + for D in D1: + np.testing.assert_array_almost_equal(D, x * D2) + + @pytest.mark.parametrize("n", [2, 3, 5]) + def test_single_tree_no_sites(self, n): + ts = tskit.Tree.generate_balanced(n, span=10).tree_sequence + D = check_divmat(ts, mode="site") + np.testing.assert_array_equal(D, np.zeros((n, n))) + + +class TestExamples: + @pytest.mark.parametrize( + "interval", [(0, 26), (1, 3), (3.25, 13.75), (5, 10), (25.5, 26)] + ) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees_interval(self, interval, mode): + ts = tsutil.all_trees_ts(4) + ts = tsutil.insert_branch_sites(ts) + assert ts.sequence_length == 26 + check_divmat(ts, windows=interval, mode=mode) + + @pytest.mark.parametrize( + ["windows"], + [ + ([0, 26],), + ([0, 1, 2],), + (list(range(27)),), + ([5, 7, 9, 20],), + ([5.1, 5.2, 5.3, 5.5, 6],), + ([5.1, 5.2, 6.5],), + ], + ) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees_windows(self, windows, mode): + ts = tsutil.all_trees_ts(4) + ts = tsutil.insert_branch_sites(ts) + assert ts.sequence_length == 26 + D = check_divmat(ts, windows=windows, mode=mode) + assert D.shape == (len(windows) - 1, 4, 4) + + @pytest.mark.parametrize("num_windows", [1, 5, 28]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees_windows_gap_at_end(self, num_windows, mode): + tables = tsutil.all_trees_ts(4).dump_tables() + tables.sequence_length = 30 + ts = tables.tree_sequence() + ts = tsutil.insert_branch_sites(ts) + assert ts.last().num_roots == 4 + windows = np.linspace(0, ts.sequence_length, num=num_windows + 1) + check_divmat(ts, windows=windows, mode=mode) + + @pytest.mark.parametrize("n", [2, 3, 5]) + @pytest.mark.parametrize("seed", range(1, 4)) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_small_sims(self, n, seed, mode): + ts = msprime.sim_ancestry( + n, + ploidy=1, + sequence_length=1000, + recombination_rate=0.01, + random_seed=seed, + ) + assert ts.num_trees >= 2 + ts = msprime.sim_mutations( + ts, rate=0.1, discrete_genome=False, random_seed=seed + ) + assert ts.num_mutations > 1 + check_divmat(ts, verbosity=0, mode=mode) + + @pytest.mark.parametrize("n", [2, 3, 5, 15]) + @pytest.mark.parametrize("num_windows", range(1, 5)) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_sims_windows(self, n, num_windows, mode): + ts = msprime.sim_ancestry( + n, + ploidy=1, + population_size=20, + sequence_length=100, + recombination_rate=0.01, + random_seed=79234, + ) + assert ts.num_trees >= 2 + ts = msprime.sim_mutations( + ts, + rate=0.01, + discrete_genome=False, + random_seed=1234, + ) + assert ts.num_mutations >= 2 + windows = np.linspace(0, ts.sequence_length, num=num_windows + 1) + check_divmat(ts, windows=windows, mode=mode) + + @pytest.mark.parametrize("n", [2, 3, 5, 15]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_balanced_tree(self, n, mode): + ts = tskit.Tree.generate_balanced(n).tree_sequence + ts = tsutil.insert_branch_sites(ts) + # print(ts.draw_text()) + check_divmat(ts, verbosity=0, mode=mode) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_internal_sample(self, mode): + tables = tskit.Tree.generate_balanced(4).tree_sequence.dump_tables() + flags = tables.nodes.flags + flags[3] = 0 + flags[5] = tskit.NODE_IS_SAMPLE + tables.nodes.flags = flags + ts = tables.tree_sequence() + ts = tsutil.insert_branch_sites(ts) + check_divmat(ts, verbosity=0, mode=mode) + + @pytest.mark.parametrize("seed", range(1, 5)) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_one_internal_sample_sims(self, seed, mode): + ts = msprime.sim_ancestry( + 10, + ploidy=1, + population_size=20, + sequence_length=100, + recombination_rate=0.01, + random_seed=seed, + ) + t = ts.dump_tables() + # Add a new sample directly below another sample + u = t.nodes.add_row(time=-1, flags=tskit.NODE_IS_SAMPLE) + t.edges.add_row(parent=0, child=u, left=0, right=ts.sequence_length) + t.sort() + t.build_index() + ts = t.tree_sequence() + ts = tsutil.insert_branch_sites(ts) + check_divmat(ts, mode=mode) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_missing_flanks(self, mode): + ts = msprime.sim_ancestry( + 20, + ploidy=1, + population_size=20, + sequence_length=100, + recombination_rate=0.01, + random_seed=1234, + ) + assert ts.num_trees >= 2 + ts = ts.keep_intervals([[20, 80]]) + assert ts.first().interval == (0, 20) + ts = tsutil.insert_branch_sites(ts) + check_divmat(ts, mode=mode) + + @pytest.mark.parametrize("n", [2, 3, 10]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_dangling_on_samples(self, n, mode): + # Adding non sample branches below the samples does not alter + # the overall divergence *between* the samples + ts1 = tskit.Tree.generate_balanced(n).tree_sequence + ts1 = tsutil.insert_branch_sites(ts1) + D1 = check_divmat(ts1, mode=mode) + tables = ts1.dump_tables() + for u in ts1.samples(): + v = tables.nodes.add_row(time=-1) + tables.edges.add_row(left=0, right=ts1.sequence_length, parent=u, child=v) + tables.sort() + tables.build_index() + ts2 = tables.tree_sequence() + D2 = check_divmat(ts2, mode=mode) + np.testing.assert_array_almost_equal(D1, D2) + + @pytest.mark.parametrize("n", [2, 3, 10]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_dangling_on_all(self, n, mode): + # Adding non sample branches below the samples does not alter + # the overall divergence *between* the samples + ts1 = tskit.Tree.generate_balanced(n).tree_sequence + ts1 = tsutil.insert_branch_sites(ts1) + D1 = check_divmat(ts1, mode=mode) + tables = ts1.dump_tables() + for u in range(ts1.num_nodes): + v = tables.nodes.add_row(time=-1) + tables.edges.add_row(left=0, right=ts1.sequence_length, parent=u, child=v) + tables.sort() + tables.build_index() + ts2 = tables.tree_sequence() + D2 = check_divmat(ts2, mode=mode) + np.testing.assert_array_almost_equal(D1, D2) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_disconnected_non_sample_topology(self, mode): + # Adding non sample branches below the samples does not alter + # the overall divergence *between* the samples + ts1 = tskit.Tree.generate_balanced(5).tree_sequence + ts1 = tsutil.insert_branch_sites(ts1) + D1 = check_divmat(ts1, mode=mode) + tables = ts1.dump_tables() + # Add an extra bit of disconnected non-sample topology + u = tables.nodes.add_row(time=0) + v = tables.nodes.add_row(time=1) + tables.edges.add_row(left=0, right=ts1.sequence_length, parent=v, child=u) + tables.sort() + tables.build_index() + ts2 = tables.tree_sequence() + D2 = check_divmat(ts2, mode=mode) + np.testing.assert_array_almost_equal(D1, D2) + + +class TestSuiteExamples: + """ + Compare the stats API method vs the library implementation for the + suite test examples. Some of these examples are too large to run the + Python code above on. + """ + + def check(self, ts, windows=None, samples=None, num_threads=0, mode="branch"): + D1 = ts.divergence_matrix( + windows=windows, + samples=samples, + num_threads=num_threads, + mode=mode, + ) + D2 = stats_api_divergence_matrix( + ts, windows=windows, samples=samples, mode=mode + ) + assert D1.shape == D2.shape + if mode == "branch": + # If we have missing data then parts of the divmat are defined to be zero, + # so relative tolerances aren't useful. Because the stats API + # method necessarily involves subtracting away all of the previous + # values for an empty tree, there is a degree of numerical imprecision + # here. This value for atol is what is needed to get the tests to + # pass in practise. + has_missing_data = any(tree._has_isolated_samples() for tree in ts.trees()) + atol = 1e-12 if has_missing_data else 0 + np.testing.assert_allclose(D1, D2, atol=atol) + else: + assert mode == "site" + if np.any(ts.mutations_parent != tskit.NULL): + # The stats API computes something slightly different when we have + # recurrent mutations, so fall back to the naive version. + D2 = site_divergence_matrix(ts, windows=windows, samples=samples) + np.testing.assert_array_equal(D1, D2) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_defaults(self, ts, mode): + self.check(ts, mode=mode) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_subset_samples(self, ts, mode): + n = min(ts.num_samples, 2) + self.check(ts, samples=ts.samples()[:n], mode=mode) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_windows(self, ts, mode): + windows = np.linspace(0, ts.sequence_length, num=13) + self.check(ts, windows=windows, mode=mode) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_threads_no_windows(self, ts, mode): + self.check(ts, num_threads=5, mode=mode) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_threads_windows(self, ts, mode): + windows = np.linspace(0, ts.sequence_length, num=11) + self.check(ts, num_threads=5, windows=windows, mode=mode) + + +class TestThreadsNoWindows: + def check(self, ts, num_threads, samples=None, mode=None): + D1 = ts.divergence_matrix(num_threads=0, samples=samples, mode=mode) + D2 = ts.divergence_matrix(num_threads=num_threads, samples=samples, mode=mode) + np.testing.assert_array_almost_equal(D1, D2) + + @pytest.mark.parametrize("num_threads", [1, 2, 3, 5, 26, 27]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees(self, num_threads, mode): + ts = tsutil.all_trees_ts(4) + assert ts.num_trees == 26 + self.check(ts, num_threads, mode=mode) + + @pytest.mark.parametrize("samples", [None, [0, 1]]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees_samples(self, samples, mode): + ts = tsutil.all_trees_ts(4) + assert ts.num_trees == 26 + self.check(ts, 2, samples, mode=mode) + + @pytest.mark.parametrize("n", [2, 3, 5, 15]) + @pytest.mark.parametrize("num_threads", range(1, 5)) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_simple_sims(self, n, num_threads, mode): + ts = msprime.sim_ancestry( + n, + ploidy=1, + population_size=20, + sequence_length=100, + recombination_rate=0.01, + random_seed=1234, + ) + assert ts.num_trees >= 2 + self.check(ts, num_threads, mode=mode) + + +class TestThreadsWindows: + def check(self, ts, num_threads, *, windows, samples=None, mode=None): + D1 = ts.divergence_matrix( + num_threads=0, windows=windows, samples=samples, mode=mode + ) + D2 = ts.divergence_matrix( + num_threads=num_threads, windows=windows, samples=samples, mode=mode + ) + np.testing.assert_array_almost_equal(D1, D2) + + @pytest.mark.parametrize("num_threads", [1, 2, 3, 5, 26, 27]) + @pytest.mark.parametrize( + ["windows"], + [ + ([0, 26],), + ([0, 1, 2],), + (list(range(27)),), + ([5, 7, 9, 20],), + ([5.1, 5.2, 5.3, 5.5, 6],), + ([5.1, 5.2, 6.5],), + ], + ) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees(self, num_threads, windows, mode): + ts = tsutil.all_trees_ts(4) + assert ts.num_trees == 26 + self.check(ts, num_threads, windows=windows, mode=mode) + + @pytest.mark.parametrize("samples", [None, [0, 1]]) + @pytest.mark.parametrize( + ["windows"], + [ + ([0, 26],), + (None,), + ], + ) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees_samples(self, samples, windows, mode): + ts = tsutil.all_trees_ts(4) + self.check(ts, 2, windows=windows, samples=samples, mode=mode) + + @pytest.mark.parametrize("num_threads", range(1, 5)) + @pytest.mark.parametrize( + ["windows"], + [ + ([0, 100],), + ([0, 50, 75, 95, 100],), + ([50, 75, 95, 100],), + ([0, 50, 75, 95],), + (list(range(100)),), + ], + ) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_simple_sims(self, num_threads, windows, mode): + ts = msprime.sim_ancestry( + 15, + ploidy=1, + population_size=20, + sequence_length=100, + recombination_rate=0.01, + random_seed=1234, + ) + assert ts.num_trees >= 2 + ts = msprime.sim_mutations(ts, rate=0.01, random_seed=1234) + assert ts.num_mutations > 10 + self.check(ts, num_threads, windows=windows, mode=mode) + + +# NOTE these are tests that are for more general functionality that might +# get applied across many different functions, and so probably should be +# tested in another file. For now they're only used by divmat, so we can +# keep them here for simplificity. +class TestChunkByTree: + # These are based on what we get from np.array_split, there's nothing + # particularly critical about exactly how we portion things up. + @pytest.mark.parametrize( + ["num_chunks", "expected"], + [ + (1, [[0, 26]]), + (2, [[0, 13], [13, 26]]), + (3, [[0, 9], [9, 18], [18, 26]]), + (4, [[0, 7], [7, 14], [14, 20], [20, 26]]), + (5, [[0, 6], [6, 11], [11, 16], [16, 21], [21, 26]]), + ], + ) + def test_all_trees_ts_26(self, num_chunks, expected): + ts = tsutil.all_trees_ts(4) + actual = ts._chunk_sequence_by_tree(num_chunks) + np.testing.assert_equal(actual, expected) + + @pytest.mark.parametrize( + ["num_chunks", "expected"], + [ + (1, [[0, 4]]), + (2, [[0, 2], [2, 4]]), + (3, [[0, 2], [2, 3], [3, 4]]), + (4, [[0, 1], [1, 2], [2, 3], [3, 4]]), + (5, [[0, 1], [1, 2], [2, 3], [3, 4]]), + (100, [[0, 1], [1, 2], [2, 3], [3, 4]]), + ], + ) + def test_all_trees_ts_4(self, num_chunks, expected): + ts = tsutil.all_trees_ts(3) + assert ts.num_trees == 4 + actual = ts._chunk_sequence_by_tree(num_chunks) + np.testing.assert_equal(actual, expected) + + @pytest.mark.parametrize("span", [1, 2, 5, 0.3]) + @pytest.mark.parametrize( + ["num_chunks", "expected"], + [ + (1, [[0, 4]]), + (2, [[0, 2], [2, 4]]), + (3, [[0, 2], [2, 3], [3, 4]]), + (4, [[0, 1], [1, 2], [2, 3], [3, 4]]), + (5, [[0, 1], [1, 2], [2, 3], [3, 4]]), + (100, [[0, 1], [1, 2], [2, 3], [3, 4]]), + ], + ) + def test_all_trees_ts_4_trees_span(self, span, num_chunks, expected): + tables = tsutil.all_trees_ts(3).dump_tables() + tables.edges.left *= span + tables.edges.right *= span + tables.sequence_length *= span + ts = tables.tree_sequence() + assert ts.num_trees == 4 + actual = ts._chunk_sequence_by_tree(num_chunks) + np.testing.assert_equal(actual, np.array(expected) * span) + + @pytest.mark.parametrize("num_chunks", range(1, 5)) + def test_empty_ts(self, num_chunks): + tables = tskit.TableCollection(1) + ts = tables.tree_sequence() + chunks = ts._chunk_sequence_by_tree(num_chunks) + np.testing.assert_equal(chunks, [[0, 1]]) + + @pytest.mark.parametrize("num_chunks", range(1, 5)) + def test_single_tree(self, num_chunks): + L = 10 + ts = tskit.Tree.generate_balanced(2, span=L).tree_sequence + chunks = ts._chunk_sequence_by_tree(num_chunks) + np.testing.assert_equal(chunks, [[0, L]]) + + @pytest.mark.parametrize("num_chunks", [0, -1, 0.5]) + def test_bad_chunks(self, num_chunks): + ts = tskit.Tree.generate_balanced(2).tree_sequence + with pytest.raises(ValueError, match="Number of chunks must be an integer > 0"): + ts._chunk_sequence_by_tree(num_chunks) + + +class TestChunkWindows: + # These are based on what we get from np.array_split, there's nothing + # particularly critical about exactly how we portion things up. + @pytest.mark.parametrize( + ["windows", "num_chunks", "expected"], + [ + ([0, 10], 1, [[0, 10]]), + ([0, 10], 2, [[0, 10]]), + ([0, 5, 10], 2, [[0, 5], [5, 10]]), + ([0, 5, 6, 10], 2, [[0, 5, 6], [6, 10]]), + ([0, 5, 6, 10], 3, [[0, 5], [5, 6], [6, 10]]), + ], + ) + def test_examples(self, windows, num_chunks, expected): + actual = tskit.TreeSequence._chunk_windows(windows, num_chunks) + np.testing.assert_equal(actual, expected) + + @pytest.mark.parametrize("num_chunks", [0, -1, 0.5]) + def test_bad_chunks(self, num_chunks): + with pytest.raises(ValueError, match="Number of chunks must be an integer > 0"): + tskit.TreeSequence._chunk_windows([0, 1], num_chunks) diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index ce225f1dd7..0529dda001 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -228,14 +228,14 @@ def get_gap_examples(): assert len(t.parent_dict) == 0 found = True assert found - ret.append((f"gap {x}", ts)) + ret.append((f"gap_{x}", ts)) # Give an example with a gap at the end. ts = msprime.simulate(10, random_seed=5, recombination_rate=1) tables = get_table_collection_copy(ts.dump_tables(), 2) tables.sites.clear() tables.mutations.clear() insert_uniform_mutations(tables, 100, list(ts.samples())) - ret.append(("gap at end", tables.tree_sequence())) + ret.append(("gap_at_end", tables.tree_sequence())) return ret diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 70ef08143c..530ec7223f 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -1529,6 +1529,26 @@ def test_kc_distance(self): x2 = ts2.get_kc_distance(ts1, lambda_) assert x1 == x2 + def test_divergence_matrix(self): + n = 10 + ts = self.get_example_tree_sequence(n, random_seed=12) + D = ts.divergence_matrix([0, ts.get_sequence_length()]) + assert D.shape == (1, n, n) + D = ts.divergence_matrix([0, ts.get_sequence_length()], samples=[0, 1]) + assert D.shape == (1, 2, 2) + with pytest.raises(TypeError): + ts.divergence_matrix(windoze=[0, 1]) + with pytest.raises(ValueError, match="at least 2"): + ts.divergence_matrix(windows=[0]) + with pytest.raises(_tskit.LibraryError, match="BAD_WINDOWS"): + ts.divergence_matrix(windows=[-1, 0, 1]) + with pytest.raises(ValueError): + ts.divergence_matrix(windows=[0, 1], samples="sdf") + with pytest.raises(ValueError, match="Unrecognised stats mode"): + ts.divergence_matrix(windows=[0, 1], mode="sdf") + with pytest.raises(_tskit.LibraryError, match="UNSUPPORTED_STAT_MODE"): + ts.divergence_matrix(windows=[0, 1], mode="node") + def test_load_tables_build_indexes(self): for ts in self.get_example_tree_sequences(): tables = _tskit.TableCollection(sequence_length=ts.get_sequence_length()) diff --git a/python/tests/tsutil.py b/python/tests/tsutil.py index 6f3b080ce5..34334e9be0 100644 --- a/python/tests/tsutil.py +++ b/python/tests/tsutil.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # Copyright (C) 2017 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -81,6 +81,8 @@ def insert_branch_mutations(ts, mutations_per_branch=1): Returns a copy of the specified tree sequence with a mutation on every branch in every tree. """ + if mutations_per_branch == 0: + return ts tables = ts.dump_tables() tables.sites.clear() tables.mutations.clear() @@ -146,23 +148,26 @@ def insert_discrete_time_mutations(ts, num_times=4, num_sites=10): return tables.tree_sequence() -def insert_branch_sites(ts): +def insert_branch_sites(ts, m=1): """ - Returns a copy of the specified tree sequence with a site on every branch + Returns a copy of the specified tree sequence with m sites on every branch of every tree. """ + if m == 0: + return ts tables = ts.dump_tables() tables.sites.clear() tables.mutations.clear() for tree in ts.trees(): left, right = tree.interval - delta = (right - left) / len(list(tree.nodes())) + delta = (right - left) / (m * len(list(tree.nodes()))) x = left for u in tree.nodes(): if tree.parent(u) != tskit.NULL: - site = tables.sites.add_row(position=x, ancestral_state="0") - tables.mutations.add_row(site=site, node=u, derived_state="1") - x += delta + for _ in range(m): + site = tables.sites.add_row(position=x, ancestral_state="0") + tables.mutations.add_row(site=site, node=u, derived_state="1") + x += delta add_provenance(tables.provenances, "insert_branch_sites") return tables.tree_sequence() @@ -1774,7 +1779,6 @@ def update_counts(edge, left, sign): def genealogical_nearest_neighbours(ts, focal, reference_sets): - reference_set_map = np.zeros(ts.num_nodes, dtype=int) - 1 for k, reference_set in enumerate(reference_sets): for u in reference_set: diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 9ccae3488d..1c26956494 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -7695,8 +7695,80 @@ def divergence( span_normalise=span_normalise, ) - # JK: commenting this out for now to get the other methods well tested. - # Issue: https://github.com/tskit-dev/tskit/issues/201 + ############################################ + # Pairwise sample x sample statistics + ############################################ + + def _chunk_sequence_by_tree(self, num_chunks): + """ + Return list of (left, right) genome interval tuples that contain + approximately equal numbers of trees as a 2D numpy array. A + maximum of self.num_trees single-tree intervals can be returned. + """ + if num_chunks <= 0 or int(num_chunks) != num_chunks: + raise ValueError("Number of chunks must be an integer > 0") + num_chunks = min(self.num_trees, num_chunks) + breakpoints = self.breakpoints(as_array=True)[:-1] + splits = np.array_split(breakpoints, num_chunks) + chunks = [] + for j in range(num_chunks - 1): + chunks.append((splits[j][0], splits[j + 1][0])) + chunks.append((splits[-1][0], self.sequence_length)) + return chunks + + @staticmethod + def _chunk_windows(windows, num_chunks): + """ + Returns a list of (at most) num_chunks windows, which represent splitting + up the specified list of windows into roughly equal work. + + Currently this is implemented by just splitting up into roughly equal + numbers of windows in each chunk. + """ + if num_chunks <= 0 or int(num_chunks) != num_chunks: + raise ValueError("Number of chunks must be an integer > 0") + num_chunks = min(len(windows) - 1, num_chunks) + splits = np.array_split(windows[:-1], num_chunks) + chunks = [] + for j in range(num_chunks - 1): + chunk = np.append(splits[j], splits[j + 1][0]) + chunks.append(chunk) + chunk = np.append(splits[-1], windows[-1]) + chunks.append(chunk) + return chunks + + def _parallelise_divmat_by_tree(self, num_threads, **kwargs): + """ + No windows were specified, so we can chunk up the whole genome by + tree, and do a simple sum of the results. + """ + + def worker(interval): + return self._ll_tree_sequence.divergence_matrix(interval, **kwargs) + + work = self._chunk_sequence_by_tree(num_threads) + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as pool: + results = pool.map(worker, work) + return sum(results) + + def _parallelise_divmat_by_window(self, windows, num_threads, **kwargs): + """ + We assume we have a number of windows that's >= to the number + of threads available, and let each thread have a chunk of the + windows. There will definitely cases where this leads to + pathological behaviour, so we may need a more sophisticated + strategy at some point. + """ + + def worker(sub_windows): + return self._ll_tree_sequence.divergence_matrix(sub_windows, **kwargs) + + work = self._chunk_windows(windows, num_threads) + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(worker, sub_windows) for sub_windows in work] + concurrent.futures.wait(futures) + return np.vstack([future.result() for future in futures]) + # def divergence_matrix(self, sample_sets, windows=None, mode="site"): # """ # Finds the mean divergence between pairs of samples from each set of @@ -7730,6 +7802,36 @@ def divergence( # A[w, i, j] = A[w, j, i] = x[w][k] # k += 1 # return A + # NOTE: see older definition of divmat here, which may be useful when documenting + # this function. See https://github.com/tskit-dev/tskit/issues/2781 + def divergence_matrix( + self, *, windows=None, samples=None, num_threads=0, mode=None + ): + windows_specified = windows is not None + windows = [0, self.sequence_length] if windows is None else windows + + mode = "site" if mode is None else mode + + # NOTE: maybe we want to use a different default for num_threads here, just + # following the approach in GNN + if num_threads <= 0: + D = self._ll_tree_sequence.divergence_matrix( + windows, samples=samples, mode=mode + ) + else: + if windows_specified: + D = self._parallelise_divmat_by_window( + windows, num_threads, samples=samples, mode=mode + ) + else: + D = self._parallelise_divmat_by_tree( + num_threads, samples=samples, mode=mode + ) + + if not windows_specified: + # Drop the windows dimension + D = D[0] + return D def genetic_relatedness( self,