From 5eb173ab4e45a05b35b8cbbdcfbed83a0b3d1853 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Mon, 10 Jul 2023 16:21:15 +0100 Subject: [PATCH 1/3] Fix benchmark CI --- .github/workflows/tests.yml | 6 +++--- python/requirements/benchmark.txt | 9 +++++++++ 2 files changed, 12 insertions(+), 3 deletions(-) create mode 100644 python/requirements/benchmark.txt diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2da829034e..473f04e893 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -34,11 +34,11 @@ jobs: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: - python-version: '3.10' + python-version: '3.11' cache: 'pip' - cache-dependency-path: python/requirements/development.txt + cache-dependency-path: python/requirements/benchmark.txt - name: Install deps - run: pip install -r python/requirements/development.txt + run: pip install -r python/requirements/benchmark.txt - name: Build module run: | cd python diff --git a/python/requirements/benchmark.txt b/python/requirements/benchmark.txt new file mode 100644 index 0000000000..12a0be4060 --- /dev/null +++ b/python/requirements/benchmark.txt @@ -0,0 +1,9 @@ +click +psutil +tqdm +matplotlib +si-prefix +jsonschema +svgwrite +msprime +PyYAML \ No newline at end of file From a3b095f5d38357326185095c395a19bc6a13bdb4 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 17 Feb 2023 09:25:05 +0000 Subject: [PATCH 2/3] Divergence matrix tree-by-tree algorithms Implement the basic version of the divergence matrix operation using tree-by-tree algorithms, and provide interface for parallelising along the genome. --- c/tests/test_stats.c | 345 ++++++++++- c/tests/test_trees.c | 9 +- c/tests/testlib.c | 12 +- c/tests/testlib.h | 4 +- c/tskit/core.c | 9 + c/tskit/core.h | 10 + c/tskit/trees.c | 576 ++++++++++++++++- c/tskit/trees.h | 4 + python/_tskitmodule.c | 76 +++ python/tests/test_divmat.py | 1064 ++++++++++++++++++++++++++++++++ python/tests/test_highlevel.py | 4 +- python/tests/test_lowlevel.py | 20 + python/tests/tsutil.py | 20 +- python/tskit/trees.py | 106 +++- 14 files changed, 2230 insertions(+), 29 deletions(-) create mode 100644 python/tests/test_divmat.py diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c index 35991288d4..2d5bc97fd5 100644 --- a/c/tests/test_stats.c +++ b/c/tests/test_stats.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -262,6 +262,48 @@ verify_mean_descendants(tsk_treeseq_t *ts) free(C); } +/* Check the divergence matrix by running against the stats API equivalent + * code. NOTE: this will not always be equal in site mode, because of a slightly + * different definition wrt to multiple mutations at a site. 
+ */ +static void +verify_divergence_matrix(tsk_treeseq_t *ts, tsk_flags_t mode) +{ + int ret; + const tsk_size_t n = tsk_treeseq_get_num_samples(ts); + const tsk_id_t *samples = tsk_treeseq_get_samples(ts); + tsk_size_t sample_set_sizes[n]; + tsk_id_t index_tuples[2 * n * n]; + double D1[n * n], D2[n * n]; + tsk_size_t i, j, k; + + for (j = 0; j < n; j++) { + sample_set_sizes[j] = 1; + for (k = 0; k < n; k++) { + index_tuples[2 * (j * n + k)] = (tsk_id_t) j; + index_tuples[2 * (j * n + k) + 1] = (tsk_id_t) k; + } + } + ret = tsk_treeseq_divergence( + ts, n, sample_set_sizes, samples, n * n, index_tuples, 0, NULL, mode, D1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_treeseq_divergence_matrix(ts, 0, NULL, 0, NULL, mode, D2); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + for (j = 0; j < n; j++) { + for (k = 0; k < n; k++) { + i = j * n + k; + /* printf("%d\t%d\t%f\t%f\n", (int) j, (int) k, D1[i], D2[i]); */ + if (j == k) { + CU_ASSERT_EQUAL(D2[i], 0); + } else { + CU_ASSERT_DOUBLE_EQUAL(D1[i], D2[i], 1E-6); + } + } + } +} + typedef struct { int call_count; int error_on; @@ -973,6 +1015,128 @@ test_single_tree_general_stat_errors(void) tsk_treeseq_free(&ts); } +static void +test_single_tree_divergence_matrix(void) +{ + tsk_treeseq_t ts; + int ret; + double result[16]; + double D_branch[16] = { 0, 2, 6, 6, 2, 0, 6, 6, 6, 6, 0, 4, 6, 6, 4, 0 }; + double D_site[16] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; + + tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, NULL, + NULL, NULL, NULL, 0); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(16, result, D_branch); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(16, result, D_site); + + verify_divergence_matrix(&ts, TSK_STAT_BRANCH); + verify_divergence_matrix(&ts, TSK_STAT_SITE); + + tsk_treeseq_free(&ts); +} + +static void +test_single_tree_divergence_matrix_internal_samples(void) +{ + tsk_treeseq_t ts; + int ret; + double result[16]; + double D[16] = { 0, 2, 4, 3, 2, 0, 4, 3, 4, 4, 0, 1, 3, 3, 1, 0 }; + + const char *nodes = "1 0 -1 -1\n" /* 2.00┊ 6 ┊ */ + "1 0 -1 -1\n" /* ┊ ┏━┻━┓ ┊ */ + "1 0 -1 -1\n" /* 1.00┊ 4 5* ┊ */ + "0 0 -1 -1\n" /* ┊ ┏┻┓ ┏┻┓ ┊ */ + "0 1 -1 -1\n" /* 0.00┊ 0 1 2 3 ┊ */ + "1 1 -1 -1\n" /* 0 * * * 1 */ + "0 2 -1 -1\n"; + const char *edges = "0 1 4 0,1\n" + "0 1 5 2,3\n" + "0 1 6 4,5\n"; + /* One mutations per branch so we get the same as the branch length value */ + const char *sites = "0.1 A\n" + "0.2 A\n" + "0.3 A\n" + "0.4 A\n" + "0.5 A\n" + "0.6 A\n"; + const char *mutations = "0 0 T -1\n" + "1 1 T -1\n" + "2 2 T -1\n" + "3 3 T -1\n" + "4 4 T -1\n" + "5 5 T -1\n"; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, sites, mutations, NULL, NULL, 0); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(16, result, D); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(16, result, D); + + verify_divergence_matrix(&ts, TSK_STAT_BRANCH); + verify_divergence_matrix(&ts, TSK_STAT_SITE); + + tsk_treeseq_free(&ts); +} + +static void +test_single_tree_divergence_matrix_multi_root(void) +{ + tsk_treeseq_t ts; + int ret; + double result[16]; + double D_branch[16] = { 0, 2, 3, 3, 2, 0, 3, 3, 3, 3, 0, 4, 3, 
3, 4, 0 }; + double D_site[16] = { 0, 4, 6, 6, 4, 0, 6, 6, 6, 6, 0, 8, 6, 6, 8, 0 }; + + const char *nodes = "1 0 -1 -1\n" + "1 0 -1 -1\n" /* 2.00┊ 5 ┊ */ + "1 0 -1 -1\n" /* 1.00┊ 4 ┊ */ + "1 0 -1 -1\n" /* ┊ ┏┻┓ ┏┻┓ ┊ */ + "0 1 -1 -1\n" /* 0.00┊ 0 1 2 3 ┊ */ + "0 2 -1 -1\n"; /* 0 * * * * 1 */ + const char *edges = "0 1 4 0,1\n" + "0 1 5 2,3\n"; + /* Two mutations per branch unit so we get twice branch length value */ + const char *sites = "0.1 A\n" + "0.2 A\n" + "0.3 A\n" + "0.4 A\n"; + const char *mutations = "0 0 B -1\n" + "0 0 C 0\n" + "1 1 B -1\n" + "1 1 C 2\n" + "2 2 B -1\n" + "2 2 C 4\n" + "2 2 D 5\n" + "2 2 E 6\n" + "3 3 B -1\n" + "3 3 C 8\n" + "3 3 D 9\n" + "3 3 E 10\n"; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, sites, mutations, NULL, NULL, 0); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(16, result, D_branch); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(16, result, D_site); + + tsk_treeseq_free(&ts); +} + static void test_paper_ex_ld(void) { @@ -1592,6 +1756,20 @@ test_paper_ex_afs(void) tsk_treeseq_free(&ts); } +static void +test_paper_ex_divergence_matrix(void) +{ + tsk_treeseq_t ts; + + tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, + paper_ex_mutations, paper_ex_individuals, NULL, 0); + + verify_divergence_matrix(&ts, TSK_STAT_BRANCH); + verify_divergence_matrix(&ts, TSK_STAT_SITE); + + tsk_treeseq_free(&ts); +} + static void test_nonbinary_ex_ld(void) { @@ -1726,6 +1904,158 @@ test_ld_silent_mutations(void) free(base_ts); } +static void +test_simplest_divergence_matrix(void) +{ + const char *nodes = "1 0 0\n" + "1 0 0\n" + "0 1 0\n"; + const char *edges = "0 1 2 0,1\n"; + tsk_treeseq_t ts; + tsk_id_t sample_ids[] = { 0, 1 }; + double D_branch[4] = { 0, 2, 2, 0 }; + double D_site[4] = { 0, 0, 0, 0 }; + double result[4]; + int ret; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0); + + ret = tsk_treeseq_divergence_matrix( + &ts, 2, sample_ids, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(4, D_branch, result); + + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(4, D_site, result); + + ret = tsk_treeseq_divergence_matrix( + &ts, 2, sample_ids, 0, NULL, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(4, D_site, result); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(4, D_branch, result); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(4, D_site, result); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_NODE, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSUPPORTED_STAT_MODE); + + ret = tsk_treeseq_divergence_matrix( + &ts, 0, NULL, 0, NULL, TSK_STAT_SPAN_NORMALISE, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED); + + ret = tsk_treeseq_divergence_matrix( + &ts, 0, NULL, 0, NULL, TSK_STAT_POLARISED, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_STAT_POLARISED_UNSUPPORTED); + + ret = tsk_treeseq_divergence_matrix( + &ts, 0, NULL, 0, NULL, TSK_STAT_SITE | 
TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MULTIPLE_STAT_MODES); + + sample_ids[0] = -1; + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + + sample_ids[0] = 3; + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + + tsk_treeseq_free(&ts); +} + +static void +test_simplest_divergence_matrix_windows(void) +{ + const char *nodes = "1 0 0\n" + "1 0 0\n" + "0 1 0\n"; + const char *edges = "0 1 2 0,1\n"; + tsk_treeseq_t ts; + tsk_id_t sample_ids[] = { 0, 1 }; + double D_branch[8] = { 0, 1, 1, 0, 0, 1, 1, 0 }; + double D_site[8] = { 0, 0, 0, 0, 0, 0, 0, 0 }; + double result[8]; + double windows[] = { 0, 0.5, 1 }; + int ret; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0); + + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 2, windows, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(8, D_site, result); + ret = tsk_treeseq_divergence_matrix( + &ts, 2, sample_ids, 2, windows, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(8, D_branch, result); + + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, windows, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_NUM_WINDOWS); + + windows[0] = -1; + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 2, windows, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); + + windows[0] = 0.45; + windows[2] = 1.5; + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 2, windows, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); + + windows[0] = 0.55; + windows[2] = 1.0; + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 2, windows, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); + + tsk_treeseq_free(&ts); +} + +static void +test_simplest_divergence_matrix_internal_sample(void) +{ + const char *nodes = "1 0 0\n" + "1 0 0\n" + "0 1 0\n"; + const char *edges = "0 1 2 0,1\n"; + tsk_treeseq_t ts; + tsk_id_t sample_ids[] = { 0, 1, 2 }; + double result[9]; + double D_branch[9] = { 0, 2, 1, 2, 0, 1, 1, 1, 0 }; + double D_site[9] = { 0, 0, 0, 0, 0, 0, 0, 0, 0 }; + int ret; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0); + + ret = tsk_treeseq_divergence_matrix( + &ts, 3, sample_ids, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(9, D_branch, result); + + ret = tsk_treeseq_divergence_matrix( + &ts, 3, sample_ids, 0, NULL, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(9, D_site, result); + + tsk_treeseq_free(&ts); +} + +static void +test_multiroot_divergence_matrix(void) +{ + tsk_treeseq_t ts; + + tsk_treeseq_from_text(&ts, 10, multiroot_ex_nodes, multiroot_ex_edges, NULL, + multiroot_ex_sites, multiroot_ex_mutations, NULL, NULL, 0); + + verify_divergence_matrix(&ts, TSK_STAT_BRANCH); + verify_divergence_matrix(&ts, TSK_STAT_SITE); + + tsk_treeseq_free(&ts); +} + int main(int argc, char **argv) { @@ -1745,6 +2075,11 @@ main(int argc, char **argv) test_single_tree_genealogical_nearest_neighbours }, { "test_single_tree_general_stat", test_single_tree_general_stat }, { "test_single_tree_general_stat_errors", test_single_tree_general_stat_errors }, + { "test_single_tree_divergence_matrix", test_single_tree_divergence_matrix }, + { "test_single_tree_divergence_matrix_internal_samples", + 
test_single_tree_divergence_matrix_internal_samples }, + { "test_single_tree_divergence_matrix_multi_root", + test_single_tree_divergence_matrix_multi_root }, { "test_paper_ex_ld", test_paper_ex_ld }, { "test_paper_ex_mean_descendants", test_paper_ex_mean_descendants }, @@ -1785,6 +2120,7 @@ main(int argc, char **argv) { "test_paper_ex_f4", test_paper_ex_f4 }, { "test_paper_ex_afs_errors", test_paper_ex_afs_errors }, { "test_paper_ex_afs", test_paper_ex_afs }, + { "test_paper_ex_divergence_matrix", test_paper_ex_divergence_matrix }, { "test_nonbinary_ex_ld", test_nonbinary_ex_ld }, { "test_nonbinary_ex_mean_descendants", test_nonbinary_ex_mean_descendants }, @@ -1798,6 +2134,13 @@ main(int argc, char **argv) { "test_ld_multi_mutations", test_ld_multi_mutations }, { "test_ld_silent_mutations", test_ld_silent_mutations }, + { "test_simplest_divergence_matrix", test_simplest_divergence_matrix }, + { "test_simplest_divergence_matrix_windows", + test_simplest_divergence_matrix_windows }, + { "test_simplest_divergence_matrix_internal_sample", + test_simplest_divergence_matrix_internal_sample }, + { "test_multiroot_divergence_matrix", test_multiroot_divergence_matrix }, + { NULL, NULL }, }; return test_main(tests, argc, argv); diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c index 94e33ee487..cceb11d6fd 100644 --- a/c/tests/test_trees.c +++ b/c/tests/test_trees.c @@ -5395,7 +5395,6 @@ test_simplify_keep_input_roots_multi_tree(void) tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, paper_ex_mutations, paper_ex_individuals, NULL, 0); - tsk_treeseq_dump(&ts, "tmp.trees", 0); ret = tsk_treeseq_simplify( &ts, samples, 2, TSK_SIMPLIFY_KEEP_INPUT_ROOTS, &simplified, NULL); CU_ASSERT_EQUAL_FATAL(ret, 0); @@ -7801,7 +7800,7 @@ test_time_uncalibrated(void) tsk_size_t sample_set_sizes[] = { 2, 2 }; tsk_id_t samples[] = { 0, 1, 2, 3 }; tsk_size_t num_samples; - double result[10]; + double result[100]; double *W; double *sigma; @@ -7857,6 +7856,12 @@ test_time_uncalibrated(void) TSK_STAT_BRANCH | TSK_STAT_ALLOW_TIME_UNCALIBRATED, sigma); CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_divergence_matrix(&ts2, 0, NULL, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_TIME_UNCALIBRATED); + ret = tsk_treeseq_divergence_matrix(&ts2, 0, NULL, 0, NULL, + TSK_STAT_BRANCH | TSK_STAT_ALLOW_TIME_UNCALIBRATED, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_safe_free(W); tsk_safe_free(sigma); tsk_treeseq_free(&ts); diff --git a/c/tests/testlib.c b/c/tests/testlib.c index 823068d136..043ae5ceab 100644 --- a/c/tests/testlib.c +++ b/c/tests/testlib.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -966,6 +966,16 @@ tskit_suite_init(void) return CUE_SUCCESS; } +void +assert_arrays_almost_equal(tsk_size_t len, double *a, double *b) +{ + tsk_size_t j; + + for (j = 0; j < len; j++) { + CU_ASSERT_DOUBLE_EQUAL(a[j], b[j], 1e-9); + } +} + static int tskit_suite_cleanup(void) { diff --git a/c/tests/testlib.h b/c/tests/testlib.h index d042d60b55..69efb14781 100644 --- a/c/tests/testlib.h +++ b/c/tests/testlib.h @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2021 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this 
software and associated documentation files (the "Software"), to deal @@ -54,6 +54,8 @@ void parse_individuals(const char *text, tsk_individual_table_t *individual_tabl void unsort_edges(tsk_edge_table_t *edges, size_t start); +void assert_arrays_almost_equal(tsk_size_t len, double *a, double *b); + extern const char *single_tree_ex_nodes; extern const char *single_tree_ex_edges; extern const char *single_tree_ex_sites; diff --git a/c/tskit/core.c b/c/tskit/core.c index b1ea25badd..100cc78cad 100644 --- a/c/tskit/core.c +++ b/c/tskit/core.c @@ -466,6 +466,15 @@ tsk_strerror_internal(int err) ret = "Statistics using branch lengths cannot be calculated when time_units " "is 'uncalibrated'. (TSK_ERR_TIME_UNCALIBRATED)"; break; + case TSK_ERR_STAT_POLARISED_UNSUPPORTED: + ret = "The TSK_STAT_POLARISED option is not supported by this statistic. " + "(TSK_ERR_STAT_POLARISED_UNSUPPORTED)"; + break; + case TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED: + ret = "The TSK_STAT_SPAN_NORMALISE option is not supported by this " + "statistic. " + "(TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED)"; + break; /* Mutation mapping errors */ case TSK_ERR_GENOTYPES_ALL_MISSING: diff --git a/c/tskit/core.h b/c/tskit/core.h index b8b9f354ba..4d2c95212d 100644 --- a/c/tskit/core.h +++ b/c/tskit/core.h @@ -675,6 +675,16 @@ Statistics based on branch lengths were attempted when the ``time_units`` were ``uncalibrated``. */ #define TSK_ERR_TIME_UNCALIBRATED -910 +/** +The TSK_STAT_POLARISED option was passed to a statistic that does +not support it. +*/ +#define TSK_ERR_STAT_POLARISED_UNSUPPORTED -911 +/** +The TSK_STAT_SPAN_NORMALISE option was passed to a statistic that does +not support it. +*/ +#define TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED -912 /** @} */ /** diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 4604579e0b..cd0ad36aa2 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -1191,9 +1191,11 @@ tsk_treeseq_mean_descendants(const tsk_treeseq_t *self, * General stats framework ***********************************/ +#define TSK_REQUIRE_FULL_SPAN 1 + static int -tsk_treeseq_check_windows( - const tsk_treeseq_t *self, tsk_size_t num_windows, const double *windows) +tsk_treeseq_check_windows(const tsk_treeseq_t *self, tsk_size_t num_windows, + const double *windows, tsk_flags_t options) { int ret = TSK_ERR_BAD_WINDOWS; tsk_size_t j; @@ -1202,12 +1204,23 @@ tsk_treeseq_check_windows( ret = TSK_ERR_BAD_NUM_WINDOWS; goto out; } - /* TODO these restrictions can be lifted later if we want a specific interval. */ - if (windows[0] != 0) { - goto out; - } - if (windows[num_windows] != self->tables->sequence_length) { - goto out; + if (options & TSK_REQUIRE_FULL_SPAN) { + /* TODO the general stat code currently requires that we include the + * entire tree sequence span. 
This should be relaxed, so hopefully + * this branch (and the option) can be removed at some point */ + if (windows[0] != 0) { + goto out; + } + if (windows[num_windows] != self->tables->sequence_length) { + goto out; + } + } else { + if (windows[0] < 0) { + goto out; + } + if (windows[num_windows] > self->tables->sequence_length) { + goto out; + } } for (j = 0; j < num_windows; j++) { if (windows[j] >= windows[j + 1]) { @@ -1960,7 +1973,8 @@ tsk_treeseq_general_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, num_windows = 1; windows = default_windows; } else { - ret = tsk_treeseq_check_windows(self, num_windows, windows); + ret = tsk_treeseq_check_windows( + self, num_windows, windows, TSK_REQUIRE_FULL_SPAN); if (ret != 0) { goto out; } @@ -2468,7 +2482,7 @@ tsk_treeseq_allele_frequency_spectrum(const tsk_treeseq_t *self, bool stat_site = !!(options & TSK_STAT_SITE); bool stat_branch = !!(options & TSK_STAT_BRANCH); bool stat_node = !!(options & TSK_STAT_NODE); - double default_windows[] = { 0, self->tables->sequence_length }; + const double default_windows[] = { 0, self->tables->sequence_length }; const tsk_size_t num_nodes = self->tables->nodes.num_rows; const tsk_size_t K = num_sample_sets + 1; tsk_size_t j, k, l, afs_size; @@ -2496,7 +2510,8 @@ tsk_treeseq_allele_frequency_spectrum(const tsk_treeseq_t *self, num_windows = 1; windows = default_windows; } else { - ret = tsk_treeseq_check_windows(self, num_windows, windows); + ret = tsk_treeseq_check_windows( + self, num_windows, windows, TSK_REQUIRE_FULL_SPAN); if (ret != 0) { goto out; } @@ -3331,7 +3346,7 @@ tsk_treeseq_simplify(const tsk_treeseq_t *self, const tsk_id_t *samples, } ret = tsk_treeseq_init( output, tables, TSK_TS_INIT_BUILD_INDEXES | TSK_TAKE_OWNERSHIP); - /* Once tsk_tree_init has returned ownership of tables is transferred */ + /* Once tsk_treeseq_init has returned ownership of tables is transferred */ tables = NULL; out: if (tables != NULL) { @@ -3460,6 +3475,20 @@ tsk_treeseq_split_edges(const tsk_treeseq_t *self, double time, tsk_flags_t flag * Tree * ======================================================== */ +/* Return the root for the specified node. + * NOTE: no bounds checking is done here. 
+ */ +static tsk_id_t +tsk_tree_get_node_root(const tsk_tree_t *self, tsk_id_t u) +{ + const tsk_id_t *restrict parent = self->parent; + + while (parent[u] != TSK_NULL) { + u = parent[u]; + } + return u; +} + int TSK_WARN_UNUSED tsk_tree_init(tsk_tree_t *self, const tsk_treeseq_t *tree_sequence, tsk_flags_t options) { @@ -6009,3 +6038,526 @@ tsk_treeseq_kc_distance(const tsk_treeseq_t *self, const tsk_treeseq_t *other, } return ret; } + +/* + * Divergence matrix + */ + +typedef struct { + /* Note it's a waste storing the triply linked tree here, but the code + * is written on the assumption of 1-based trees and the algorithm is + * frighteningly subtle, so it doesn't seem worth messing with it + * unless we really need to save some memory */ + tsk_id_t *parent; + tsk_id_t *child; + tsk_id_t *sib; + tsk_id_t *lambda; + tsk_id_t *pi; + tsk_id_t *tau; + tsk_id_t *beta; + tsk_id_t *alpha; +} sv_tables_t; + +static int +sv_tables_init(sv_tables_t *self, tsk_size_t n) +{ + int ret = 0; + + self->parent = tsk_malloc(n * sizeof(*self->parent)); + self->child = tsk_malloc(n * sizeof(*self->child)); + self->sib = tsk_malloc(n * sizeof(*self->sib)); + self->pi = tsk_malloc(n * sizeof(*self->pi)); + self->lambda = tsk_malloc(n * sizeof(*self->lambda)); + self->tau = tsk_malloc(n * sizeof(*self->tau)); + self->beta = tsk_malloc(n * sizeof(*self->beta)); + self->alpha = tsk_malloc(n * sizeof(*self->alpha)); + if (self->parent == NULL || self->child == NULL || self->sib == NULL + || self->lambda == NULL || self->tau == NULL || self->beta == NULL + || self->alpha == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } +out: + return ret; +} + +static int +sv_tables_free(sv_tables_t *self) +{ + tsk_safe_free(self->parent); + tsk_safe_free(self->child); + tsk_safe_free(self->sib); + tsk_safe_free(self->lambda); + tsk_safe_free(self->pi); + tsk_safe_free(self->tau); + tsk_safe_free(self->beta); + tsk_safe_free(self->alpha); + return 0; +} +static void +sv_tables_reset(sv_tables_t *self, tsk_tree_t *tree) +{ + const tsk_size_t n = 1 + tree->num_nodes; + tsk_memset(self->parent, 0, n * sizeof(*self->parent)); + tsk_memset(self->child, 0, n * sizeof(*self->child)); + tsk_memset(self->sib, 0, n * sizeof(*self->sib)); + tsk_memset(self->pi, 0, n * sizeof(*self->pi)); + tsk_memset(self->lambda, 0, n * sizeof(*self->lambda)); + tsk_memset(self->tau, 0, n * sizeof(*self->tau)); + tsk_memset(self->beta, 0, n * sizeof(*self->beta)); + tsk_memset(self->alpha, 0, n * sizeof(*self->alpha)); +} + +static void +sv_tables_convert_tree(sv_tables_t *self, tsk_tree_t *tree) +{ + const tsk_size_t n = 1 + tree->num_nodes; + const tsk_id_t *restrict tsk_parent = tree->parent; + tsk_id_t *restrict child = self->child; + tsk_id_t *restrict parent = self->parent; + tsk_id_t *restrict sib = self->sib; + tsk_size_t j; + tsk_id_t u, v; + + for (j = 0; j < n - 1; j++) { + u = (tsk_id_t) j + 1; + v = tsk_parent[j] + 1; + sib[u] = child[v]; + child[v] = u; + parent[u] = v; + } +} + +#define LAMBDA 0 + +static void +sv_tables_build_index(sv_tables_t *self) +{ + const tsk_id_t *restrict child = self->child; + const tsk_id_t *restrict parent = self->parent; + const tsk_id_t *restrict sib = self->sib; + tsk_id_t *restrict lambda = self->lambda; + tsk_id_t *restrict pi = self->pi; + tsk_id_t *restrict tau = self->tau; + tsk_id_t *restrict beta = self->beta; + tsk_id_t *restrict alpha = self->alpha; + tsk_id_t a, n, p, h; + + p = child[LAMBDA]; + n = 0; + lambda[0] = -1; + while (p != LAMBDA) { + while (true) { + n++; + pi[p] = n; + tau[n] = 
LAMBDA; + lambda[n] = 1 + lambda[n >> 1]; + if (child[p] != LAMBDA) { + p = child[p]; + } else { + break; + } + } + beta[p] = n; + while (true) { + tau[beta[p]] = parent[p]; + if (sib[p] != LAMBDA) { + p = sib[p]; + break; + } else { + p = parent[p]; + if (p != LAMBDA) { + h = lambda[n & -pi[p]]; + beta[p] = ((n >> h) | 1) << h; + } else { + break; + } + } + } + } + + /* Begin the second traversal */ + lambda[0] = lambda[n]; + pi[LAMBDA] = 0; + beta[LAMBDA] = 0; + alpha[LAMBDA] = 0; + p = child[LAMBDA]; + while (p != LAMBDA) { + while (true) { + a = alpha[parent[p]] | (beta[p] & -beta[p]); + alpha[p] = a; + if (child[p] != LAMBDA) { + p = child[p]; + } else { + break; + } + } + while (true) { + if (sib[p] != LAMBDA) { + p = sib[p]; + break; + } else { + p = parent[p]; + if (p == LAMBDA) { + break; + } + } + } + } +} + +static void +sv_tables_build(sv_tables_t *self, tsk_tree_t *tree) +{ + sv_tables_reset(self, tree); + sv_tables_convert_tree(self, tree); + sv_tables_build_index(self); +} + +static tsk_id_t +sv_tables_mrca_one_based(const sv_tables_t *self, tsk_id_t x, tsk_id_t y) +{ + const tsk_id_t *restrict lambda = self->lambda; + const tsk_id_t *restrict pi = self->pi; + const tsk_id_t *restrict tau = self->tau; + const tsk_id_t *restrict beta = self->beta; + const tsk_id_t *restrict alpha = self->alpha; + tsk_id_t h, k, xhat, yhat, ell, j, z; + + if (beta[x] <= beta[y]) { + h = lambda[beta[y] & -beta[x]]; + } else { + h = lambda[beta[x] & -beta[y]]; + } + k = alpha[x] & alpha[y] & -(1 << h); + h = lambda[k & -k]; + j = ((beta[x] >> h) | 1) << h; + if (j == beta[x]) { + xhat = x; + } else { + ell = lambda[alpha[x] & ((1 << h) - 1)]; + xhat = tau[((beta[x] >> ell) | 1) << ell]; + } + if (j == beta[y]) { + yhat = y; + } else { + ell = lambda[alpha[y] & ((1 << h) - 1)]; + yhat = tau[((beta[y] >> ell) | 1) << ell]; + } + if (pi[xhat] <= pi[yhat]) { + z = xhat; + } else { + z = yhat; + } + return z; +} + +static tsk_id_t +sv_tables_mrca(const sv_tables_t *self, tsk_id_t x, tsk_id_t y) +{ + /* Convert to 1-based indexes and back */ + return sv_tables_mrca_one_based(self, x + 1, y + 1) - 1; +} + +static int +tsk_treeseq_check_node_bounds( + const tsk_treeseq_t *self, tsk_size_t num_nodes, const tsk_id_t *nodes) +{ + int ret = 0; + tsk_size_t j; + tsk_id_t u; + const tsk_id_t N = (tsk_id_t) self->tables->nodes.num_rows; + + for (j = 0; j < num_nodes; j++) { + u = nodes[j]; + if (u < 0 || u >= N) { + ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + goto out; + } + } +out: + return ret; +} + +static int +tsk_treeseq_divergence_matrix_branch(const tsk_treeseq_t *self, tsk_size_t num_samples, + const tsk_id_t *restrict samples, tsk_size_t num_windows, + const double *restrict windows, tsk_flags_t options, double *restrict result) +{ + int ret = 0; + tsk_tree_t tree; + const double *restrict nodes_time = self->tables->nodes.time; + const tsk_size_t n = num_samples; + tsk_size_t i, j, k; + tsk_id_t u, v, w, u_root, v_root; + double tu, tv, d, span, left, right, span_left, span_right; + double *restrict D; + sv_tables_t sv; + + memset(&sv, 0, sizeof(sv)); + ret = tsk_tree_init(&tree, self, 0); + if (ret != 0) { + goto out; + } + ret = sv_tables_init(&sv, self->tables->nodes.num_rows + 1); + if (ret != 0) { + goto out; + } + + if (self->time_uncalibrated && !(options & TSK_STAT_ALLOW_TIME_UNCALIBRATED)) { + ret = TSK_ERR_TIME_UNCALIBRATED; + goto out; + } + + for (i = 0; i < num_windows; i++) { + left = windows[i]; + right = windows[i + 1]; + D = result + i * n * n; + ret = tsk_tree_seek(&tree, left, 0); + if (ret 
!= 0) { + goto out; + } + while (tree.interval.left < right && tree.index != -1) { + span_left = TSK_MAX(tree.interval.left, left); + span_right = TSK_MIN(tree.interval.right, right); + span = span_right - span_left; + sv_tables_build(&sv, &tree); + for (j = 0; j < n; j++) { + u = samples[j]; + for (k = j + 1; k < n; k++) { + v = samples[k]; + w = sv_tables_mrca(&sv, u, v); + if (w != TSK_NULL) { + u_root = w; + v_root = w; + } else { + /* Slow path - only happens for nodes in disconnected + * subtrees in a tree with multiple roots */ + u_root = tsk_tree_get_node_root(&tree, u); + v_root = tsk_tree_get_node_root(&tree, v); + } + tu = nodes_time[u_root] - nodes_time[u]; + tv = nodes_time[v_root] - nodes_time[v]; + d = (tu + tv) * span; + D[j * n + k] += d; + } + } + ret = tsk_tree_next(&tree); + if (ret < 0) { + goto out; + } + } + } + ret = 0; +out: + tsk_tree_free(&tree); + sv_tables_free(&sv); + return ret; +} + +static tsk_size_t +count_mutations_on_path(tsk_id_t u, tsk_id_t v, const tsk_id_t *restrict parent, + const double *restrict time, const tsk_size_t *restrict mutations_per_node) +{ + double tu, tv; + tsk_size_t count = 0; + + tu = time[u]; + tv = time[v]; + while (u != v) { + if (tu < tv) { + count += mutations_per_node[u]; + u = parent[u]; + if (u == TSK_NULL) { + break; + } + tu = time[u]; + } else { + count += mutations_per_node[v]; + v = parent[v]; + if (v == TSK_NULL) { + break; + } + tv = time[v]; + } + } + if (u != v) { + while (u != TSK_NULL) { + count += mutations_per_node[u]; + u = parent[u]; + } + while (v != TSK_NULL) { + count += mutations_per_node[v]; + v = parent[v]; + } + } + return count; +} + +static int +tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_samples, + const tsk_id_t *restrict samples, tsk_size_t num_windows, + const double *restrict windows, tsk_flags_t TSK_UNUSED(options), + double *restrict result) +{ + int ret = 0; + tsk_tree_t tree; + const tsk_size_t n = num_samples; + const tsk_size_t num_nodes = self->tables->nodes.num_rows; + const double *restrict nodes_time = self->tables->nodes.time; + tsk_size_t i, j, k, tree_site, tree_mut; + tsk_site_t site; + tsk_mutation_t mut; + tsk_id_t u, v; + double left, right, span_left, span_right; + double *restrict D; + tsk_size_t *mutations_per_node = tsk_malloc(num_nodes * sizeof(*mutations_per_node)); + + ret = tsk_tree_init(&tree, self, 0); + if (ret != 0) { + goto out; + } + if (mutations_per_node == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + + for (i = 0; i < num_windows; i++) { + left = windows[i]; + right = windows[i + 1]; + D = result + i * n * n; + ret = tsk_tree_seek(&tree, left, 0); + if (ret != 0) { + goto out; + } + while (tree.interval.left < right && tree.index != -1) { + span_left = TSK_MAX(tree.interval.left, left); + span_right = TSK_MIN(tree.interval.right, right); + + /* NOTE: we could avoid this full memset across all nodes by doing + * the same loops again and decrementing at the end of the main + * tree-loop. 
It's probably not worth it though, because of the + * overwhelming O(n^2) below */ + tsk_memset(mutations_per_node, 0, num_nodes * sizeof(*mutations_per_node)); + for (tree_site = 0; tree_site < tree.sites_length; tree_site++) { + site = tree.sites[tree_site]; + if (span_left <= site.position && site.position < span_right) { + for (tree_mut = 0; tree_mut < site.mutations_length; tree_mut++) { + mut = site.mutations[tree_mut]; + mutations_per_node[mut.node]++; + } + } + } + + for (j = 0; j < n; j++) { + u = samples[j]; + for (k = j + 1; k < n; k++) { + v = samples[k]; + D[j * n + k] += (double) count_mutations_on_path( + u, v, tree.parent, nodes_time, mutations_per_node); + } + } + ret = tsk_tree_next(&tree); + if (ret < 0) { + goto out; + } + } + } + ret = 0; +out: + tsk_tree_free(&tree); + tsk_safe_free(mutations_per_node); + return ret; +} + +static void +fill_lower_triangle( + double *restrict result, const tsk_size_t n, const tsk_size_t num_windows) +{ + tsk_size_t i, j, k; + double *restrict D; + + /* TODO there's probably a better striding pattern that could be used here */ + for (i = 0; i < num_windows; i++) { + D = result + i * n * n; + for (j = 0; j < n; j++) { + for (k = j + 1; k < n; k++) { + D[k * n + j] = D[j * n + k]; + } + } + } +} + +int +tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples, + const tsk_id_t *samples_in, tsk_size_t num_windows, const double *windows, + tsk_flags_t options, double *result) +{ + int ret = 0; + const tsk_id_t *samples = self->samples; + tsk_size_t n = self->num_samples; + const double default_windows[] = { 0, self->tables->sequence_length }; + bool stat_site = !!(options & TSK_STAT_SITE); + bool stat_branch = !!(options & TSK_STAT_BRANCH); + bool stat_node = !!(options & TSK_STAT_NODE); + + if (stat_node) { + ret = TSK_ERR_UNSUPPORTED_STAT_MODE; + goto out; + } + /* If no mode is specified, we default to site mode */ + if (!(stat_site || stat_branch)) { + stat_site = true; + } + /* It's an error to specify more than one mode */ + if (stat_site + stat_branch > 1) { + ret = TSK_ERR_MULTIPLE_STAT_MODES; + goto out; + } + + if (options & TSK_STAT_POLARISED) { + ret = TSK_ERR_STAT_POLARISED_UNSUPPORTED; + goto out; + } + if (options & TSK_STAT_SPAN_NORMALISE) { + ret = TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED; + goto out; + } + + if (windows == NULL) { + num_windows = 1; + windows = default_windows; + } else { + ret = tsk_treeseq_check_windows(self, num_windows, windows, 0); + if (ret != 0) { + goto out; + } + } + + if (samples_in != NULL) { + samples = samples_in; + n = num_samples; + ret = tsk_treeseq_check_node_bounds(self, n, samples); + if (ret != 0) { + goto out; + } + } + + tsk_memset(result, 0, num_windows * n * n * sizeof(*result)); + + if (stat_branch) { + ret = tsk_treeseq_divergence_matrix_branch( + self, n, samples, num_windows, windows, options, result); + } else { + tsk_bug_assert(stat_site); + ret = tsk_treeseq_divergence_matrix_site( + self, n, samples, num_windows, windows, options, result); + } + if (ret != 0) { + goto out; + } + fill_lower_triangle(result, n, num_windows); + +out: + return ret; +} diff --git a/c/tskit/trees.h b/c/tskit/trees.h index efe9980077..10c820e1c0 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1003,6 +1003,10 @@ int tsk_treeseq_f4(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); +int tsk_treeseq_divergence_matrix(const 
tsk_treeseq_t *self, tsk_size_t num_samples, + const tsk_id_t *samples, tsk_size_t num_windows, const double *windows, + tsk_flags_t options, double *result); + /****************************************************************************/ /* Tree */ /****************************************************************************/ diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index dea3c03fd9..8d42b50afc 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -9638,6 +9638,78 @@ TreeSequence_f4(TreeSequence *self, PyObject *args, PyObject *kwds) return TreeSequence_k_way_stat_method(self, args, kwds, 4, tsk_treeseq_f4); } +static PyObject * +TreeSequence_divergence_matrix(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + PyObject *ret = NULL; + static char *kwlist[] = { "windows", "samples", "mode", NULL }; + PyArrayObject *result_array = NULL; + PyObject *windows = NULL; + PyObject *py_samples = Py_None; + char *mode = NULL; + PyArrayObject *windows_array = NULL; + PyArrayObject *samples_array = NULL; + tsk_flags_t options = 0; + npy_intp *shape, dims[3]; + tsk_size_t num_samples, num_windows; + tsk_id_t *samples = NULL; + int err; + + if (TreeSequence_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords( + args, kwds, "O|Os", kwlist, &windows, &py_samples, &mode)) { + goto out; + } + num_samples = tsk_treeseq_get_num_samples(self->tree_sequence); + if (py_samples != Py_None) { + samples_array = (PyArrayObject *) PyArray_FROMANY( + py_samples, NPY_INT32, 1, 1, NPY_ARRAY_IN_ARRAY); + if (samples_array == NULL) { + goto out; + } + shape = PyArray_DIMS(samples_array); + samples = PyArray_DATA(samples_array); + num_samples = (tsk_size_t) shape[0]; + } + if (parse_windows(windows, &windows_array, &num_windows) != 0) { + goto out; + } + dims[0] = num_windows; + dims[1] = num_samples; + dims[2] = num_samples; + result_array = (PyArrayObject *) PyArray_SimpleNew(3, dims, NPY_FLOAT64); + if (result_array == NULL) { + goto out; + } + if (parse_stats_mode(mode, &options) != 0) { + goto out; + } + // clang-format off + Py_BEGIN_ALLOW_THREADS + err = tsk_treeseq_divergence_matrix( + self->tree_sequence, + num_samples, samples, + num_windows, PyArray_DATA(windows_array), + options, PyArray_DATA(result_array)); + Py_END_ALLOW_THREADS + // clang-format on + /* Clang-format insists on doing this in spite of the "off" instruction above */ + if (err != 0) + { + handle_library_error(err); + goto out; + } + ret = (PyObject *) result_array; + result_array = NULL; +out: + Py_XDECREF(result_array); + Py_XDECREF(windows_array); + Py_XDECREF(samples_array); + return ret; +} + static PyObject * TreeSequence_get_num_mutations(TreeSequence *self) { @@ -10346,6 +10418,10 @@ static PyMethodDef TreeSequence_methods[] = { .ml_meth = (PyCFunction) TreeSequence_f4, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Computes the f4 statistic." }, + { .ml_name = "divergence_matrix", + .ml_meth = (PyCFunction) TreeSequence_divergence_matrix, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the pairwise divergence matrix." 
}, { .ml_name = "split_edges", .ml_meth = (PyCFunction) TreeSequence_split_edges, .ml_flags = METH_VARARGS | METH_KEYWORDS, diff --git a/python/tests/test_divmat.py b/python/tests/test_divmat.py new file mode 100644 index 0000000000..acb2403d41 --- /dev/null +++ b/python/tests/test_divmat.py @@ -0,0 +1,1064 @@ +# MIT License +# +# Copyright (c) 2023 Tskit Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Test cases for divergence matrix based pairwise stats +""" +import collections + +import msprime +import numpy as np +import pytest + +import tskit +from tests import tsutil +from tests.test_highlevel import get_example_tree_sequences + +# ↑ See https://github.com/tskit-dev/tskit/issues/1804 for when +# we can remove this. + +DIVMAT_MODES = ["branch", "site"] + +# NOTE: this implementation of Schieber-Vishkin algorithm is done like +# this so it's easy to run with numba. It would be more naturally +# packaged as a class. We don't actually use numba here, but it's +# handy to have a version of the SV code lying around that can be +# run directly with numba. + + +def sv_tables_init(parent_array): + n = 1 + parent_array.shape[0] + + LAMBDA = 0 + # Triply-linked tree. 
FIXME we shouldn't need to build this as it's + # available already in tskit + child = np.zeros(n, dtype=np.int32) + parent = np.zeros(n, dtype=np.int32) + sib = np.zeros(n, dtype=np.int32) + + for j in range(n - 1): + u = j + 1 + v = parent_array[j] + 1 + sib[u] = child[v] + child[v] = u + parent[u] = v + + lambd = np.zeros(n, dtype=np.int32) + pi = np.zeros(n, dtype=np.int32) + tau = np.zeros(n, dtype=np.int32) + beta = np.zeros(n, dtype=np.int32) + alpha = np.zeros(n, dtype=np.int32) + + p = child[LAMBDA] + n = 0 + lambd[0] = -1 + while p != LAMBDA: + while True: + n += 1 + pi[p] = n + tau[n] = LAMBDA + lambd[n] = 1 + lambd[n >> 1] + if child[p] != LAMBDA: + p = child[p] + else: + break + beta[p] = n + while True: + tau[beta[p]] = parent[p] + if sib[p] != LAMBDA: + p = sib[p] + break + else: + p = parent[p] + if p != LAMBDA: + h = lambd[n & -pi[p]] + beta[p] = ((n >> h) | 1) << h + else: + break + + # Begin the second traversal + lambd[0] = lambd[n] + pi[LAMBDA] = 0 + beta[LAMBDA] = 0 + alpha[LAMBDA] = 0 + p = child[LAMBDA] + while p != LAMBDA: + while True: + a = alpha[parent[p]] | (beta[p] & -beta[p]) + alpha[p] = a + if child[p] != LAMBDA: + p = child[p] + else: + break + while True: + if sib[p] != LAMBDA: + p = sib[p] + break + else: + p = parent[p] + if p == LAMBDA: + break + + return lambd, pi, tau, beta, alpha + + +def _sv_mrca(x, y, lambd, pi, tau, beta, alpha): + if beta[x] <= beta[y]: + h = lambd[beta[y] & -beta[x]] + else: + h = lambd[beta[x] & -beta[y]] + k = alpha[x] & alpha[y] & -(1 << h) + h = lambd[k & -k] + j = ((beta[x] >> h) | 1) << h + if j == beta[x]: + xhat = x + else: + ell = lambd[alpha[x] & ((1 << h) - 1)] + xhat = tau[((beta[x] >> ell) | 1) << ell] + if j == beta[y]: + yhat = y + else: + ell = lambd[alpha[y] & ((1 << h) - 1)] + yhat = tau[((beta[y] >> ell) | 1) << ell] + if pi[xhat] <= pi[yhat]: + z = xhat + else: + z = yhat + return z + + +def sv_mrca(x, y, lambd, pi, tau, beta, alpha): + # Convert to 1-based indexes + return _sv_mrca(x + 1, y + 1, lambd, pi, tau, beta, alpha) - 1 + + +def local_root(tree, u): + while tree.parent(u) != tskit.NULL: + u = tree.parent(u) + return u + + +def branch_divergence_matrix(ts, windows=None, samples=None): + windows_specified = windows is not None + windows = [0, ts.sequence_length] if windows is None else windows + num_windows = len(windows) - 1 + samples = ts.samples() if samples is None else samples + + n = len(samples) + D = np.zeros((num_windows, n, n)) + tree = tskit.Tree(ts) + for i in range(num_windows): + left = windows[i] + right = windows[i + 1] + # print(f"WINDOW {i} [{left}, {right})") + tree.seek(left) + # Iterate over the trees in this window + while tree.interval.left < right and tree.index != -1: + span_left = max(tree.interval.left, left) + span_right = min(tree.interval.right, right) + span = span_right - span_left + # print(f"\ttree {tree.interval} [{span_left}, {span_right})") + tables = sv_tables_init(tree.parent_array) + for j in range(n): + u = samples[j] + for k in range(j + 1, n): + v = samples[k] + w = sv_mrca(u, v, *tables) + assert w == tree.mrca(u, v) + if w != tskit.NULL: + tu = ts.nodes_time[w] - ts.nodes_time[u] + tv = ts.nodes_time[w] - ts.nodes_time[v] + else: + tu = ts.nodes_time[local_root(tree, u)] - ts.nodes_time[u] + tv = ts.nodes_time[local_root(tree, v)] - ts.nodes_time[v] + d = (tu + tv) * span + D[i, j, k] += d + tree.next() + # Fill out symmetric triangle in the matrix + for j in range(n): + for k in range(j + 1, n): + D[i, k, j] = D[i, j, k] + if not windows_specified: + D = 
D[0] + return D + + +def divergence_matrix(ts, windows=None, samples=None, mode="site"): + assert mode in ["site", "branch"] + if mode == "site": + return site_divergence_matrix(ts, samples=samples, windows=windows) + else: + return branch_divergence_matrix(ts, samples=samples, windows=windows) + + +def stats_api_divergence_matrix(ts, windows=None, samples=None, mode="site"): + samples = ts.samples() if samples is None else samples + windows_specified = windows is not None + windows = [0, ts.sequence_length] if windows is None else list(windows) + num_windows = len(windows) - 1 + + if len(samples) == 0: + # FIXME: the code general stat code doesn't seem to handle zero samples + # case, need to identify MWE and file issue. + if windows_specified: + return np.zeros(shape=(num_windows, 0, 0)) + else: + return np.zeros(shape=(0, 0)) + + # Make sure that all the specified samples have the sample flag set, otherwise + # the library code will complain + tables = ts.dump_tables() + flags = tables.nodes.flags + # NOTE: this is a shortcut, setting all flags unconditionally to zero, so don't + # use this tree sequence outside this method. + flags[:] = 0 + flags[samples] = tskit.NODE_IS_SAMPLE + tables.nodes.flags = flags + ts = tables.tree_sequence() + + # FIXME We have to go through this annoying rigmarole because windows must start and + # end with 0 and L. We should relax this requirement to just making the windows + # contiguous, so that we just look at specific sections of the genome. + drop = [] + if windows[0] != 0: + windows = [0] + windows + drop.append(0) + if windows[-1] != ts.sequence_length: + windows.append(ts.sequence_length) + drop.append(-1) + + n = len(samples) + sample_sets = [[u] for u in samples] + indexes = [(i, j) for i in range(n) for j in range(n)] + X = ts.divergence( + sample_sets, + indexes=indexes, + mode=mode, + span_normalise=False, + windows=windows, + ) + keep = np.ones(len(windows) - 1, dtype=bool) + keep[drop] = False + X = X[keep] + out = X.reshape((X.shape[0], n, n)) + for D in out: + np.fill_diagonal(D, 0) + if not windows_specified: + out = out[0] + return out + + +def rootward_path(tree, u, v): + while u != v: + yield u + u = tree.parent(u) + + +def site_divergence_matrix(ts, windows=None, samples=None): + windows_specified = windows is not None + windows = [0, ts.sequence_length] if windows is None else windows + num_windows = len(windows) - 1 + samples = ts.samples() if samples is None else samples + + n = len(samples) + D = np.zeros((num_windows, n, n)) + tree = tskit.Tree(ts) + for i in range(num_windows): + left = windows[i] + right = windows[i + 1] + tree.seek(left) + # Iterate over the trees in this window + while tree.interval.left < right and tree.index != -1: + span_left = max(tree.interval.left, left) + span_right = min(tree.interval.right, right) + mutations_per_node = collections.Counter() + for site in tree.sites(): + if span_left <= site.position < span_right: + for mutation in site.mutations: + mutations_per_node[mutation.node] += 1 + for j in range(n): + u = samples[j] + for k in range(j + 1, n): + v = samples[k] + w = tree.mrca(u, v) + if w != tskit.NULL: + wu = w + wv = w + else: + wu = local_root(tree, u) + wv = local_root(tree, v) + du = sum(mutations_per_node[x] for x in rootward_path(tree, u, wu)) + dv = sum(mutations_per_node[x] for x in rootward_path(tree, v, wv)) + # NOTE: we're just accumulating the raw mutation counts, not + # multiplying by span + D[i, j, k] += du + dv + tree.next() + # Fill out symmetric triangle in the matrix + 
for j in range(n): + for k in range(j + 1, n): + D[i, k, j] = D[i, j, k] + if not windows_specified: + D = D[0] + return D + + +def check_divmat( + ts, + *, + windows=None, + samples=None, + verbosity=0, + compare_stats_api=True, + compare_lib=True, + mode="site", +): + np.set_printoptions(linewidth=500, precision=4) + # print(ts.draw_text()) + if verbosity > 1: + print(ts.draw_text()) + + D1 = divergence_matrix(ts, windows=windows, samples=samples, mode=mode) + if compare_stats_api: + # Somethings like duplicate samples aren't worth hacking around for in + # stats API. + D2 = stats_api_divergence_matrix( + ts, windows=windows, samples=samples, mode=mode + ) + # print("windows = ", windows) + # print(D1) + # print(D2) + np.testing.assert_allclose(D1, D2) + assert D1.shape == D2.shape + if compare_lib: + D3 = ts.divergence_matrix(windows=windows, samples=samples, mode=mode) + # print(D3) + assert D1.shape == D3.shape + np.testing.assert_allclose(D1, D3) + return D1 + + +class TestExamplesWithAnswer: + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_zero_samples(self, mode): + ts = tskit.Tree.generate_balanced(2).tree_sequence + D = check_divmat(ts, samples=[], mode="site") + assert D.shape == (0, 0) + + @pytest.mark.parametrize("num_windows", [1, 2, 3, 5]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_zero_samples_windows(self, num_windows, mode): + ts = tskit.Tree.generate_balanced(2).tree_sequence + windows = np.linspace(0, ts.sequence_length, num=num_windows + 1) + D = check_divmat(ts, samples=[], windows=windows, mode="site") + assert D.shape == (num_windows, 0, 0) + + @pytest.mark.parametrize("m", [0, 1, 2, 10]) + def test_single_tree_sites_per_branch(self, m): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_sites(ts, m) + D1 = check_divmat(ts, mode="site") + D2 = np.array( + [ + [0.0, 2.0, 4.0, 4.0], + [2.0, 0.0, 4.0, 4.0], + [4.0, 4.0, 0.0, 2.0], + [4.0, 4.0, 2.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, m * D2) + + @pytest.mark.parametrize("m", [0, 1, 2, 10]) + def test_single_tree_mutations_per_branch(self, m): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_mutations(ts, m) + # The stats API will produce a different value here, because + # we're just counting up the mutations and not reasoning about + # the state of samples at all. 
+ D1 = check_divmat(ts, mode="site", compare_stats_api=False) + D2 = np.array( + [ + [0.0, 2.0, 4.0, 4.0], + [2.0, 0.0, 4.0, 4.0], + [4.0, 4.0, 0.0, 2.0], + [4.0, 4.0, 2.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, m * D2) + + @pytest.mark.parametrize("L", [0.1, 1, 2, 100]) + def test_single_tree_sequence_length(self, L): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4, span=L).tree_sequence + D1 = check_divmat(ts, mode="branch") + D2 = np.array( + [ + [0.0, 2.0, 4.0, 4.0], + [2.0, 0.0, 4.0, 4.0], + [4.0, 4.0, 0.0, 2.0], + [4.0, 4.0, 2.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, L * D2) + + @pytest.mark.parametrize("num_windows", [1, 2, 3, 5]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_gap_at_end(self, num_windows, mode): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ 0 1 2 3 + # 0 1 2 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_sites(ts) + tables = ts.dump_tables() + tables.sequence_length = 2 + ts = tables.tree_sequence() + windows = np.linspace(0, ts.sequence_length, num=num_windows + 1) + D1 = check_divmat(ts, windows=windows, mode=mode) + D1 = np.sum(D1, axis=0) + D2 = np.array( + [ + [0.0, 2.0, 4.0, 4.0], + [2.0, 0.0, 4.0, 4.0], + [4.0, 4.0, 0.0, 2.0], + [4.0, 4.0, 2.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, D2) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_subset_permuted_samples(self, mode): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_sites(ts) + D1 = check_divmat(ts, samples=[1, 2, 0], mode=mode) + D2 = np.array( + [ + [0.0, 4.0, 2.0], + [4.0, 0.0, 4.0], + [2.0, 4.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, D2) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_mixed_non_sample_samples(self, mode): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_sites(ts) + D1 = check_divmat(ts, samples=[0, 5], mode=mode) + D2 = np.array( + [ + [0.0, 3.0], + [3.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, D2) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_duplicate_samples(self, mode): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_sites(ts) + D1 = check_divmat(ts, samples=[0, 0, 1], compare_stats_api=False, mode=mode) + D2 = np.array( + [ + [0.0, 0.0, 2.0], + [0.0, 0.0, 2.0], + [2.0, 2.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, D2) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_multiroot(self, mode): + # 2.00┊ ┊ + # ┊ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_sites(ts) + ts = ts.decapitate(1) + D1 = check_divmat(ts, mode=mode) + D2 = np.array( + [ + [0.0, 2.0, 2.0, 2.0], + [2.0, 0.0, 2.0, 2.0], + [2.0, 2.0, 0.0, 2.0], + [2.0, 2.0, 2.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, D2) + + @pytest.mark.parametrize( + ["left", "right"], [(0, 10), (1, 3), (3.25, 3.75), (5, 10)] + ) + def test_single_tree_interval(self, left, right): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 
2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4, span=10).tree_sequence + D1 = check_divmat(ts, windows=[left, right], mode="branch") + D2 = np.array( + [ + [0.0, 2.0, 4.0, 4.0], + [2.0, 0.0, 4.0, 4.0], + [4.0, 4.0, 0.0, 2.0], + [4.0, 4.0, 2.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1[0], (right - left) * D2) + + @pytest.mark.parametrize("num_windows", [1, 2, 3, 5, 11]) + def test_single_tree_equal_windows(self, num_windows): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4, span=10).tree_sequence + windows = np.linspace(0, ts.sequence_length, num=num_windows + 1) + x = ts.sequence_length / num_windows + # print(windows) + D1 = check_divmat(ts, windows=windows, mode="branch") + assert D1.shape == (num_windows, 4, 4) + D2 = np.array( + [ + [0.0, 2.0, 4.0, 4.0], + [2.0, 0.0, 4.0, 4.0], + [4.0, 4.0, 0.0, 2.0], + [4.0, 4.0, 2.0, 0.0], + ] + ) + for D in D1: + np.testing.assert_array_almost_equal(D, x * D2) + + @pytest.mark.parametrize("n", [2, 3, 5]) + def test_single_tree_no_sites(self, n): + ts = tskit.Tree.generate_balanced(n, span=10).tree_sequence + D = check_divmat(ts, mode="site") + np.testing.assert_array_equal(D, np.zeros((n, n))) + + +class TestExamples: + @pytest.mark.parametrize( + "interval", [(0, 26), (1, 3), (3.25, 13.75), (5, 10), (25.5, 26)] + ) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees_interval(self, interval, mode): + ts = tsutil.all_trees_ts(4) + ts = tsutil.insert_branch_sites(ts) + assert ts.sequence_length == 26 + check_divmat(ts, windows=interval, mode=mode) + + @pytest.mark.parametrize( + ["windows"], + [ + ([0, 26],), + ([0, 1, 2],), + (list(range(27)),), + ([5, 7, 9, 20],), + ([5.1, 5.2, 5.3, 5.5, 6],), + ([5.1, 5.2, 6.5],), + ], + ) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees_windows(self, windows, mode): + ts = tsutil.all_trees_ts(4) + ts = tsutil.insert_branch_sites(ts) + assert ts.sequence_length == 26 + D = check_divmat(ts, windows=windows, mode=mode) + assert D.shape == (len(windows) - 1, 4, 4) + + @pytest.mark.parametrize("num_windows", [1, 5, 28]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees_windows_gap_at_end(self, num_windows, mode): + tables = tsutil.all_trees_ts(4).dump_tables() + tables.sequence_length = 30 + ts = tables.tree_sequence() + ts = tsutil.insert_branch_sites(ts) + assert ts.last().num_roots == 4 + windows = np.linspace(0, ts.sequence_length, num=num_windows + 1) + check_divmat(ts, windows=windows, mode=mode) + + @pytest.mark.parametrize("n", [2, 3, 5]) + @pytest.mark.parametrize("seed", range(1, 4)) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_small_sims(self, n, seed, mode): + ts = msprime.sim_ancestry( + n, + ploidy=1, + sequence_length=1000, + recombination_rate=0.01, + random_seed=seed, + ) + assert ts.num_trees >= 2 + ts = msprime.sim_mutations( + ts, rate=0.1, discrete_genome=False, random_seed=seed + ) + assert ts.num_mutations > 1 + check_divmat(ts, verbosity=0, mode=mode) + + @pytest.mark.parametrize("n", [2, 3, 5, 15]) + @pytest.mark.parametrize("num_windows", range(1, 5)) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_sims_windows(self, n, num_windows, mode): + ts = msprime.sim_ancestry( + n, + ploidy=1, + population_size=20, + sequence_length=100, + recombination_rate=0.01, + random_seed=79234, + ) + assert ts.num_trees >= 2 + ts = msprime.sim_mutations( + ts, + rate=0.01, + discrete_genome=False, + random_seed=1234, + ) + assert 
ts.num_mutations >= 2 + windows = np.linspace(0, ts.sequence_length, num=num_windows + 1) + check_divmat(ts, windows=windows, mode=mode) + + @pytest.mark.parametrize("n", [2, 3, 5, 15]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_balanced_tree(self, n, mode): + ts = tskit.Tree.generate_balanced(n).tree_sequence + ts = tsutil.insert_branch_sites(ts) + # print(ts.draw_text()) + check_divmat(ts, verbosity=0, mode=mode) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_internal_sample(self, mode): + tables = tskit.Tree.generate_balanced(4).tree_sequence.dump_tables() + flags = tables.nodes.flags + flags[3] = 0 + flags[5] = tskit.NODE_IS_SAMPLE + tables.nodes.flags = flags + ts = tables.tree_sequence() + ts = tsutil.insert_branch_sites(ts) + check_divmat(ts, verbosity=0, mode=mode) + + @pytest.mark.parametrize("seed", range(1, 5)) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_one_internal_sample_sims(self, seed, mode): + ts = msprime.sim_ancestry( + 10, + ploidy=1, + population_size=20, + sequence_length=100, + recombination_rate=0.01, + random_seed=seed, + ) + t = ts.dump_tables() + # Add a new sample directly below another sample + u = t.nodes.add_row(time=-1, flags=tskit.NODE_IS_SAMPLE) + t.edges.add_row(parent=0, child=u, left=0, right=ts.sequence_length) + t.sort() + t.build_index() + ts = t.tree_sequence() + ts = tsutil.insert_branch_sites(ts) + check_divmat(ts, mode=mode) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_missing_flanks(self, mode): + ts = msprime.sim_ancestry( + 20, + ploidy=1, + population_size=20, + sequence_length=100, + recombination_rate=0.01, + random_seed=1234, + ) + assert ts.num_trees >= 2 + ts = ts.keep_intervals([[20, 80]]) + assert ts.first().interval == (0, 20) + ts = tsutil.insert_branch_sites(ts) + check_divmat(ts, mode=mode) + + @pytest.mark.parametrize("n", [2, 3, 10]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_dangling_on_samples(self, n, mode): + # Adding non sample branches below the samples does not alter + # the overall divergence *between* the samples + ts1 = tskit.Tree.generate_balanced(n).tree_sequence + ts1 = tsutil.insert_branch_sites(ts1) + D1 = check_divmat(ts1, mode=mode) + tables = ts1.dump_tables() + for u in ts1.samples(): + v = tables.nodes.add_row(time=-1) + tables.edges.add_row(left=0, right=ts1.sequence_length, parent=u, child=v) + tables.sort() + tables.build_index() + ts2 = tables.tree_sequence() + D2 = check_divmat(ts2, mode=mode) + np.testing.assert_array_almost_equal(D1, D2) + + @pytest.mark.parametrize("n", [2, 3, 10]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_dangling_on_all(self, n, mode): + # Adding non sample branches below the samples does not alter + # the overall divergence *between* the samples + ts1 = tskit.Tree.generate_balanced(n).tree_sequence + ts1 = tsutil.insert_branch_sites(ts1) + D1 = check_divmat(ts1, mode=mode) + tables = ts1.dump_tables() + for u in range(ts1.num_nodes): + v = tables.nodes.add_row(time=-1) + tables.edges.add_row(left=0, right=ts1.sequence_length, parent=u, child=v) + tables.sort() + tables.build_index() + ts2 = tables.tree_sequence() + D2 = check_divmat(ts2, mode=mode) + np.testing.assert_array_almost_equal(D1, D2) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_disconnected_non_sample_topology(self, mode): + # Adding non sample branches below the samples does not alter + # the overall divergence *between* the samples + ts1 = tskit.Tree.generate_balanced(5).tree_sequence + ts1 = 
tsutil.insert_branch_sites(ts1) + D1 = check_divmat(ts1, mode=mode) + tables = ts1.dump_tables() + # Add an extra bit of disconnected non-sample topology + u = tables.nodes.add_row(time=0) + v = tables.nodes.add_row(time=1) + tables.edges.add_row(left=0, right=ts1.sequence_length, parent=v, child=u) + tables.sort() + tables.build_index() + ts2 = tables.tree_sequence() + D2 = check_divmat(ts2, mode=mode) + np.testing.assert_array_almost_equal(D1, D2) + + +class TestSuiteExamples: + """ + Compare the stats API method vs the library implementation for the + suite test examples. Some of these examples are too large to run the + Python code above on. + """ + + def check(self, ts, windows=None, samples=None, num_threads=0, mode="branch"): + D1 = ts.divergence_matrix( + windows=windows, + samples=samples, + num_threads=num_threads, + mode=mode, + ) + D2 = stats_api_divergence_matrix( + ts, windows=windows, samples=samples, mode=mode + ) + assert D1.shape == D2.shape + if mode == "branch": + # If we have missing data then parts of the divmat are defined to be zero, + # so relative tolerances aren't useful. Because the stats API + # method necessarily involves subtracting away all of the previous + # values for an empty tree, there is a degree of numerical imprecision + # here. This value for atol is what is needed to get the tests to + # pass in practise. + has_missing_data = any(tree._has_isolated_samples() for tree in ts.trees()) + atol = 1e-12 if has_missing_data else 0 + np.testing.assert_allclose(D1, D2, atol=atol) + else: + assert mode == "site" + if np.any(ts.mutations_parent != tskit.NULL): + # The stats API computes something slightly different when we have + # recurrent mutations, so fall back to the naive version. + D2 = site_divergence_matrix(ts, windows=windows, samples=samples) + np.testing.assert_array_equal(D1, D2) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_defaults(self, ts, mode): + self.check(ts, mode=mode) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_subset_samples(self, ts, mode): + n = min(ts.num_samples, 2) + self.check(ts, samples=ts.samples()[:n], mode=mode) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_windows(self, ts, mode): + windows = np.linspace(0, ts.sequence_length, num=13) + self.check(ts, windows=windows, mode=mode) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_threads_no_windows(self, ts, mode): + self.check(ts, num_threads=5, mode=mode) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_threads_windows(self, ts, mode): + windows = np.linspace(0, ts.sequence_length, num=11) + self.check(ts, num_threads=5, windows=windows, mode=mode) + + +class TestThreadsNoWindows: + def check(self, ts, num_threads, samples=None, mode=None): + D1 = ts.divergence_matrix(num_threads=0, samples=samples, mode=mode) + D2 = ts.divergence_matrix(num_threads=num_threads, samples=samples, mode=mode) + np.testing.assert_array_almost_equal(D1, D2) + + @pytest.mark.parametrize("num_threads", [1, 2, 3, 5, 26, 27]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees(self, num_threads, mode): + ts = tsutil.all_trees_ts(4) + assert ts.num_trees == 26 + self.check(ts, num_threads, mode=mode) 
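[Editorial note] For orientation, these threading tests exercise the new TreeSequence.divergence_matrix method added in the trees.py hunk later in this patch. A minimal usage sketch, with parameter names taken from that hunk; the msprime simulation is just an arbitrary example:

import msprime
import numpy as np

# Build a small example tree sequence with some recombination and mutation.
ts = msprime.sim_ancestry(
    8, ploidy=1, sequence_length=100, recombination_rate=0.01, random_seed=42
)
ts = msprime.sim_mutations(ts, rate=0.01, random_seed=42)

# Whole-sequence matrix: the windows dimension is dropped, giving (n, n).
D = ts.divergence_matrix(mode="site")
assert D.shape == (ts.num_samples, ts.num_samples)

# Windowed, threaded computation agrees with the serial one; the result
# has shape (num_windows, n, n).
windows = np.linspace(0, ts.sequence_length, num=5)
D1 = ts.divergence_matrix(windows=windows, mode="branch", num_threads=0)
D2 = ts.divergence_matrix(windows=windows, mode="branch", num_threads=3)
assert D1.shape == (4, ts.num_samples, ts.num_samples)
np.testing.assert_array_almost_equal(D1, D2)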
+ + @pytest.mark.parametrize("samples", [None, [0, 1]]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees_samples(self, samples, mode): + ts = tsutil.all_trees_ts(4) + assert ts.num_trees == 26 + self.check(ts, 2, samples, mode=mode) + + @pytest.mark.parametrize("n", [2, 3, 5, 15]) + @pytest.mark.parametrize("num_threads", range(1, 5)) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_simple_sims(self, n, num_threads, mode): + ts = msprime.sim_ancestry( + n, + ploidy=1, + population_size=20, + sequence_length=100, + recombination_rate=0.01, + random_seed=1234, + ) + assert ts.num_trees >= 2 + self.check(ts, num_threads, mode=mode) + + +class TestThreadsWindows: + def check(self, ts, num_threads, *, windows, samples=None, mode=None): + D1 = ts.divergence_matrix( + num_threads=0, windows=windows, samples=samples, mode=mode + ) + D2 = ts.divergence_matrix( + num_threads=num_threads, windows=windows, samples=samples, mode=mode + ) + np.testing.assert_array_almost_equal(D1, D2) + + @pytest.mark.parametrize("num_threads", [1, 2, 3, 5, 26, 27]) + @pytest.mark.parametrize( + ["windows"], + [ + ([0, 26],), + ([0, 1, 2],), + (list(range(27)),), + ([5, 7, 9, 20],), + ([5.1, 5.2, 5.3, 5.5, 6],), + ([5.1, 5.2, 6.5],), + ], + ) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees(self, num_threads, windows, mode): + ts = tsutil.all_trees_ts(4) + assert ts.num_trees == 26 + self.check(ts, num_threads, windows=windows, mode=mode) + + @pytest.mark.parametrize("samples", [None, [0, 1]]) + @pytest.mark.parametrize( + ["windows"], + [ + ([0, 26],), + (None,), + ], + ) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees_samples(self, samples, windows, mode): + ts = tsutil.all_trees_ts(4) + self.check(ts, 2, windows=windows, samples=samples, mode=mode) + + @pytest.mark.parametrize("num_threads", range(1, 5)) + @pytest.mark.parametrize( + ["windows"], + [ + ([0, 100],), + ([0, 50, 75, 95, 100],), + ([50, 75, 95, 100],), + ([0, 50, 75, 95],), + (list(range(100)),), + ], + ) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_simple_sims(self, num_threads, windows, mode): + ts = msprime.sim_ancestry( + 15, + ploidy=1, + population_size=20, + sequence_length=100, + recombination_rate=0.01, + random_seed=1234, + ) + assert ts.num_trees >= 2 + ts = msprime.sim_mutations(ts, rate=0.01, random_seed=1234) + assert ts.num_mutations > 10 + self.check(ts, num_threads, windows=windows, mode=mode) + + +# NOTE these are tests that are for more general functionality that might +# get applied across many different functions, and so probably should be +# tested in another file. For now they're only used by divmat, so we can +# keep them here for simplificity. +class TestChunkByTree: + # These are based on what we get from np.array_split, there's nothing + # particularly critical about exactly how we portion things up. 
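[Editorial note] The expectations below follow directly from np.array_split applied to the tree breakpoints. An illustrative stand-alone version of the chunking scheme; the helper name chunk_by_tree is hypothetical, the real code is _chunk_sequence_by_tree in the trees.py hunk later in this patch:

import numpy as np

def chunk_by_tree(breakpoints, sequence_length, num_chunks):
    # Split the breakpoints (excluding the final, sequence-end one) into
    # roughly equal groups; the first breakpoint of each group becomes a
    # chunk boundary, and the last chunk runs to the end of the sequence.
    num_chunks = min(len(breakpoints) - 1, num_chunks)
    splits = np.array_split(breakpoints[:-1], num_chunks)
    chunks = [(splits[j][0], splits[j + 1][0]) for j in range(num_chunks - 1)]
    chunks.append((splits[-1][0], sequence_length))
    return chunks

# 26 unit-span trees split three ways: chunks equivalent to
# [[0, 9], [9, 18], [18, 26]], matching the expectation below.
print(chunk_by_tree(np.arange(27), 26, 3))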
+ @pytest.mark.parametrize( + ["num_chunks", "expected"], + [ + (1, [[0, 26]]), + (2, [[0, 13], [13, 26]]), + (3, [[0, 9], [9, 18], [18, 26]]), + (4, [[0, 7], [7, 14], [14, 20], [20, 26]]), + (5, [[0, 6], [6, 11], [11, 16], [16, 21], [21, 26]]), + ], + ) + def test_all_trees_ts_26(self, num_chunks, expected): + ts = tsutil.all_trees_ts(4) + actual = ts._chunk_sequence_by_tree(num_chunks) + np.testing.assert_equal(actual, expected) + + @pytest.mark.parametrize( + ["num_chunks", "expected"], + [ + (1, [[0, 4]]), + (2, [[0, 2], [2, 4]]), + (3, [[0, 2], [2, 3], [3, 4]]), + (4, [[0, 1], [1, 2], [2, 3], [3, 4]]), + (5, [[0, 1], [1, 2], [2, 3], [3, 4]]), + (100, [[0, 1], [1, 2], [2, 3], [3, 4]]), + ], + ) + def test_all_trees_ts_4(self, num_chunks, expected): + ts = tsutil.all_trees_ts(3) + assert ts.num_trees == 4 + actual = ts._chunk_sequence_by_tree(num_chunks) + np.testing.assert_equal(actual, expected) + + @pytest.mark.parametrize("span", [1, 2, 5, 0.3]) + @pytest.mark.parametrize( + ["num_chunks", "expected"], + [ + (1, [[0, 4]]), + (2, [[0, 2], [2, 4]]), + (3, [[0, 2], [2, 3], [3, 4]]), + (4, [[0, 1], [1, 2], [2, 3], [3, 4]]), + (5, [[0, 1], [1, 2], [2, 3], [3, 4]]), + (100, [[0, 1], [1, 2], [2, 3], [3, 4]]), + ], + ) + def test_all_trees_ts_4_trees_span(self, span, num_chunks, expected): + tables = tsutil.all_trees_ts(3).dump_tables() + tables.edges.left *= span + tables.edges.right *= span + tables.sequence_length *= span + ts = tables.tree_sequence() + assert ts.num_trees == 4 + actual = ts._chunk_sequence_by_tree(num_chunks) + np.testing.assert_equal(actual, np.array(expected) * span) + + @pytest.mark.parametrize("num_chunks", range(1, 5)) + def test_empty_ts(self, num_chunks): + tables = tskit.TableCollection(1) + ts = tables.tree_sequence() + chunks = ts._chunk_sequence_by_tree(num_chunks) + np.testing.assert_equal(chunks, [[0, 1]]) + + @pytest.mark.parametrize("num_chunks", range(1, 5)) + def test_single_tree(self, num_chunks): + L = 10 + ts = tskit.Tree.generate_balanced(2, span=L).tree_sequence + chunks = ts._chunk_sequence_by_tree(num_chunks) + np.testing.assert_equal(chunks, [[0, L]]) + + @pytest.mark.parametrize("num_chunks", [0, -1, 0.5]) + def test_bad_chunks(self, num_chunks): + ts = tskit.Tree.generate_balanced(2).tree_sequence + with pytest.raises(ValueError, match="Number of chunks must be an integer > 0"): + ts._chunk_sequence_by_tree(num_chunks) + + +class TestChunkWindows: + # These are based on what we get from np.array_split, there's nothing + # particularly critical about exactly how we portion things up. 
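[Editorial note] The window chunking tested here likewise uses np.array_split, but duplicates one breakpoint so that adjacent chunks share a boundary and no window is lost. A sketch; chunk_windows is an illustrative stand-in for the static _chunk_windows method in the trees.py hunk later in this patch:

import numpy as np

def chunk_windows(windows, num_chunks):
    # Each chunk is a contiguous run of window breakpoints; the closing
    # breakpoint of one chunk is repeated as the opening breakpoint of
    # the next, so every window falls into exactly one chunk.
    num_chunks = min(len(windows) - 1, num_chunks)
    splits = np.array_split(np.asarray(windows)[:-1], num_chunks)
    chunks = [np.append(splits[j], splits[j + 1][0]) for j in range(num_chunks - 1)]
    chunks.append(np.append(splits[-1], windows[-1]))
    return chunks

# -> two chunks equivalent to [[0, 5, 6], [6, 10]], as in the cases below.
print(chunk_windows([0, 5, 6, 10], 2))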
+ @pytest.mark.parametrize( + ["windows", "num_chunks", "expected"], + [ + ([0, 10], 1, [[0, 10]]), + ([0, 10], 2, [[0, 10]]), + ([0, 5, 10], 2, [[0, 5], [5, 10]]), + ([0, 5, 6, 10], 2, [[0, 5, 6], [6, 10]]), + ([0, 5, 6, 10], 3, [[0, 5], [5, 6], [6, 10]]), + ], + ) + def test_examples(self, windows, num_chunks, expected): + actual = tskit.TreeSequence._chunk_windows(windows, num_chunks) + np.testing.assert_equal(actual, expected) + + @pytest.mark.parametrize("num_chunks", [0, -1, 0.5]) + def test_bad_chunks(self, num_chunks): + with pytest.raises(ValueError, match="Number of chunks must be an integer > 0"): + tskit.TreeSequence._chunk_windows([0, 1], num_chunks) diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index ce225f1dd7..0529dda001 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -228,14 +228,14 @@ def get_gap_examples(): assert len(t.parent_dict) == 0 found = True assert found - ret.append((f"gap {x}", ts)) + ret.append((f"gap_{x}", ts)) # Give an example with a gap at the end. ts = msprime.simulate(10, random_seed=5, recombination_rate=1) tables = get_table_collection_copy(ts.dump_tables(), 2) tables.sites.clear() tables.mutations.clear() insert_uniform_mutations(tables, 100, list(ts.samples())) - ret.append(("gap at end", tables.tree_sequence())) + ret.append(("gap_at_end", tables.tree_sequence())) return ret diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 70ef08143c..530ec7223f 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -1529,6 +1529,26 @@ def test_kc_distance(self): x2 = ts2.get_kc_distance(ts1, lambda_) assert x1 == x2 + def test_divergence_matrix(self): + n = 10 + ts = self.get_example_tree_sequence(n, random_seed=12) + D = ts.divergence_matrix([0, ts.get_sequence_length()]) + assert D.shape == (1, n, n) + D = ts.divergence_matrix([0, ts.get_sequence_length()], samples=[0, 1]) + assert D.shape == (1, 2, 2) + with pytest.raises(TypeError): + ts.divergence_matrix(windoze=[0, 1]) + with pytest.raises(ValueError, match="at least 2"): + ts.divergence_matrix(windows=[0]) + with pytest.raises(_tskit.LibraryError, match="BAD_WINDOWS"): + ts.divergence_matrix(windows=[-1, 0, 1]) + with pytest.raises(ValueError): + ts.divergence_matrix(windows=[0, 1], samples="sdf") + with pytest.raises(ValueError, match="Unrecognised stats mode"): + ts.divergence_matrix(windows=[0, 1], mode="sdf") + with pytest.raises(_tskit.LibraryError, match="UNSUPPORTED_STAT_MODE"): + ts.divergence_matrix(windows=[0, 1], mode="node") + def test_load_tables_build_indexes(self): for ts in self.get_example_tree_sequences(): tables = _tskit.TableCollection(sequence_length=ts.get_sequence_length()) diff --git a/python/tests/tsutil.py b/python/tests/tsutil.py index 6f3b080ce5..34334e9be0 100644 --- a/python/tests/tsutil.py +++ b/python/tests/tsutil.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # Copyright (C) 2017 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -81,6 +81,8 @@ def insert_branch_mutations(ts, mutations_per_branch=1): Returns a copy of the specified tree sequence with a mutation on every branch in every tree. 
""" + if mutations_per_branch == 0: + return ts tables = ts.dump_tables() tables.sites.clear() tables.mutations.clear() @@ -146,23 +148,26 @@ def insert_discrete_time_mutations(ts, num_times=4, num_sites=10): return tables.tree_sequence() -def insert_branch_sites(ts): +def insert_branch_sites(ts, m=1): """ - Returns a copy of the specified tree sequence with a site on every branch + Returns a copy of the specified tree sequence with m sites on every branch of every tree. """ + if m == 0: + return ts tables = ts.dump_tables() tables.sites.clear() tables.mutations.clear() for tree in ts.trees(): left, right = tree.interval - delta = (right - left) / len(list(tree.nodes())) + delta = (right - left) / (m * len(list(tree.nodes()))) x = left for u in tree.nodes(): if tree.parent(u) != tskit.NULL: - site = tables.sites.add_row(position=x, ancestral_state="0") - tables.mutations.add_row(site=site, node=u, derived_state="1") - x += delta + for _ in range(m): + site = tables.sites.add_row(position=x, ancestral_state="0") + tables.mutations.add_row(site=site, node=u, derived_state="1") + x += delta add_provenance(tables.provenances, "insert_branch_sites") return tables.tree_sequence() @@ -1774,7 +1779,6 @@ def update_counts(edge, left, sign): def genealogical_nearest_neighbours(ts, focal, reference_sets): - reference_set_map = np.zeros(ts.num_nodes, dtype=int) - 1 for k, reference_set in enumerate(reference_sets): for u in reference_set: diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 9ccae3488d..1c26956494 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -7695,8 +7695,80 @@ def divergence( span_normalise=span_normalise, ) - # JK: commenting this out for now to get the other methods well tested. - # Issue: https://github.com/tskit-dev/tskit/issues/201 + ############################################ + # Pairwise sample x sample statistics + ############################################ + + def _chunk_sequence_by_tree(self, num_chunks): + """ + Return list of (left, right) genome interval tuples that contain + approximately equal numbers of trees as a 2D numpy array. A + maximum of self.num_trees single-tree intervals can be returned. + """ + if num_chunks <= 0 or int(num_chunks) != num_chunks: + raise ValueError("Number of chunks must be an integer > 0") + num_chunks = min(self.num_trees, num_chunks) + breakpoints = self.breakpoints(as_array=True)[:-1] + splits = np.array_split(breakpoints, num_chunks) + chunks = [] + for j in range(num_chunks - 1): + chunks.append((splits[j][0], splits[j + 1][0])) + chunks.append((splits[-1][0], self.sequence_length)) + return chunks + + @staticmethod + def _chunk_windows(windows, num_chunks): + """ + Returns a list of (at most) num_chunks windows, which represent splitting + up the specified list of windows into roughly equal work. + + Currently this is implemented by just splitting up into roughly equal + numbers of windows in each chunk. 
+ """ + if num_chunks <= 0 or int(num_chunks) != num_chunks: + raise ValueError("Number of chunks must be an integer > 0") + num_chunks = min(len(windows) - 1, num_chunks) + splits = np.array_split(windows[:-1], num_chunks) + chunks = [] + for j in range(num_chunks - 1): + chunk = np.append(splits[j], splits[j + 1][0]) + chunks.append(chunk) + chunk = np.append(splits[-1], windows[-1]) + chunks.append(chunk) + return chunks + + def _parallelise_divmat_by_tree(self, num_threads, **kwargs): + """ + No windows were specified, so we can chunk up the whole genome by + tree, and do a simple sum of the results. + """ + + def worker(interval): + return self._ll_tree_sequence.divergence_matrix(interval, **kwargs) + + work = self._chunk_sequence_by_tree(num_threads) + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as pool: + results = pool.map(worker, work) + return sum(results) + + def _parallelise_divmat_by_window(self, windows, num_threads, **kwargs): + """ + We assume we have a number of windows that's >= to the number + of threads available, and let each thread have a chunk of the + windows. There will definitely cases where this leads to + pathological behaviour, so we may need a more sophisticated + strategy at some point. + """ + + def worker(sub_windows): + return self._ll_tree_sequence.divergence_matrix(sub_windows, **kwargs) + + work = self._chunk_windows(windows, num_threads) + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(worker, sub_windows) for sub_windows in work] + concurrent.futures.wait(futures) + return np.vstack([future.result() for future in futures]) + # def divergence_matrix(self, sample_sets, windows=None, mode="site"): # """ # Finds the mean divergence between pairs of samples from each set of @@ -7730,6 +7802,36 @@ def divergence( # A[w, i, j] = A[w, j, i] = x[w][k] # k += 1 # return A + # NOTE: see older definition of divmat here, which may be useful when documenting + # this function. 
See https://github.com/tskit-dev/tskit/issues/2781 + def divergence_matrix( + self, *, windows=None, samples=None, num_threads=0, mode=None + ): + windows_specified = windows is not None + windows = [0, self.sequence_length] if windows is None else windows + + mode = "site" if mode is None else mode + + # NOTE: maybe we want to use a different default for num_threads here, just + # following the approach in GNN + if num_threads <= 0: + D = self._ll_tree_sequence.divergence_matrix( + windows, samples=samples, mode=mode + ) + else: + if windows_specified: + D = self._parallelise_divmat_by_window( + windows, num_threads, samples=samples, mode=mode + ) + else: + D = self._parallelise_divmat_by_tree( + num_threads, samples=samples, mode=mode + ) + + if not windows_specified: + # Drop the windows dimension + D = D[0] + return D def genetic_relatedness( self, From 385d6d2eab1a58dd30d1d7d596da1260b85341e4 Mon Sep 17 00:00:00 2001 From: astheeggeggs Date: Wed, 31 Aug 2022 16:57:54 +0100 Subject: [PATCH 3/3] Add backwards algorithm for haploid data, using lshmm for testing --- .github/workflows/tests.yml | 2 +- .../CI-tests-pip/requirements.txt | 2 +- python/tests/test_genotype_matching_fb.py | 1 - python/tests/test_haplotype_matching.py | 1392 +++++++---------- 4 files changed, 603 insertions(+), 794 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 473f04e893..58eb68a9d3 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -83,7 +83,7 @@ jobs: /usr/share/miniconda/envs/anaconda-client-env ~/osx-conda ~/.profile - key: ${{ runner.os }}-${{ matrix.python}}-conda-v11-${{ hashFiles('python/requirements/CI-tests-conda/requirements.txt') }}-${{ hashFiles('python/requirements/CI-tests-pip/requirements.txt') }} + key: ${{ runner.os }}-${{ matrix.python}}-conda-v12-${{ hashFiles('python/requirements/CI-tests-conda/requirements.txt') }}-${{ hashFiles('python/requirements/CI-tests-pip/requirements.txt') }} - name: Install Conda uses: conda-incubator/setup-miniconda@v2 diff --git a/python/requirements/CI-tests-pip/requirements.txt b/python/requirements/CI-tests-pip/requirements.txt index e9e16c64e5..9f3b31ef3d 100644 --- a/python/requirements/CI-tests-pip/requirements.txt +++ b/python/requirements/CI-tests-pip/requirements.txt @@ -1,4 +1,4 @@ -lshmm==0.0.4; python_version < '3.11' +lshmm==0.0.4 numpy==1.21.6; python_version < '3.11' # Held at 1.21.6 for Python 3.7 compatibility numpy==1.24.1; python_version > '3.10' pytest==7.1.3 diff --git a/python/tests/test_genotype_matching_fb.py b/python/tests/test_genotype_matching_fb.py index 248382e913..761eadf403 100644 --- a/python/tests/test_genotype_matching_fb.py +++ b/python/tests/test_genotype_matching_fb.py @@ -754,7 +754,6 @@ def compute_next_probability_dict( query_is_missing, ): mu = self.mu[site_id] - template_is_hom = np.logical_not(template_is_het) if query_is_missing: diff --git a/python/tests/test_haplotype_matching.py b/python/tests/test_haplotype_matching.py index 55f102939c..b09ebcc005 100644 --- a/python/tests/test_haplotype_matching.py +++ b/python/tests/test_haplotype_matching.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2019-2021 Tskit Developers +# Copyright (c) 2019-2023 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -20,332 +20,55 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
""" -Python implementation of the Li and Stephens algorithms. +Python implementation of the Li and Stephens forwards and backwards algorithms. """ import itertools -import unittest +import lshmm as ls import msprime import numpy as np -import pytest -import _tskit # TMP import tskit -from tests import tsutil +MISSING = -1 -def in_sorted(values, j): - # Take advantage of the fact that the numpy array is sorted. - ret = False - index = np.searchsorted(values, j) - if index < values.shape[0]: - ret = values[index] == j - return ret - -def ls_forward_matrix_naive(h, alleles, G, rho, mu): - """ - Simple matrix based method for LS forward algorithm using Python loops. - """ - assert rho[0] == 0 - m, n = G.shape - alleles = check_alleles(alleles, m) - F = np.zeros((m, n)) - S = np.zeros(m) - f = np.zeros(n) + 1 / n - - for el in range(0, m): - for j in range(n): - # NOTE Careful with the difference between this expression and - # the Viterbi algorithm below. This depends on the different - # normalisation approach. - p_t = f[j] * (1 - rho[el]) + rho[el] / n - p_e = mu[el] - if G[el, j] == h[el] or h[el] == tskit.MISSING_DATA: - p_e = 1 - (len(alleles[el]) - 1) * mu[el] - f[j] = p_t * p_e - S[el] = np.sum(f) - # TODO need to handle the 0 case. - assert S[el] > 0 - f /= S[el] - F[el] = f - return F, S - - -def ls_viterbi_naive(h, alleles, G, rho, mu): - """ - Simple matrix based method for LS Viterbi algorithm using Python loops. - """ - assert rho[0] == 0 - m, n = G.shape - alleles = check_alleles(alleles, m) - L = np.ones(n) - T = [set() for _ in range(m)] - T_dest = np.zeros(m, dtype=int) - - for el in range(m): - # The calculation below is undefined otherwise. - if len(alleles[el]) > 1: - assert mu[el] <= 1 / (len(alleles[el]) - 1) - L_next = np.zeros(n) - for j in range(n): - # NOTE Careful with the difference between this expression and - # the Forward algorithm above. This depends on the different - # normalisation approach. - p_no_recomb = L[j] * (1 - rho[el] + rho[el] / n) - p_recomb = rho[el] / n - if p_no_recomb > p_recomb: - p_t = p_no_recomb - else: - p_t = p_recomb - T[el].add(j) - p_e = mu[el] - if G[el, j] == h[el] or h[el] == tskit.MISSING_DATA: - p_e = 1 - (len(alleles[el]) - 1) * mu[el] - L_next[j] = p_t * p_e - L = L_next - j = np.argmax(L) - T_dest[el] = j - if L[j] == 0: - assert mu[el] == 0 - raise ValueError( - "Trying to match non-existent allele with zero mutation rate" - ) - L /= L[j] - - P = np.zeros(m, dtype=int) - P[m - 1] = T_dest[m - 1] - for el in range(m - 1, 0, -1): - j = P[el] - if j in T[el]: - j = T_dest[el - 1] - P[el - 1] = j - return P - - -def ls_viterbi_vectorised(h, alleles, G, rho, mu): - # We must have a non-zero mutation rate, or we'll end up with - # division by zero problems. 
- # assert np.all(mu > 0) - - m, n = G.shape - alleles = check_alleles(alleles, m) - V = np.ones(n) - T = [None for _ in range(m)] - max_index = np.zeros(m, dtype=int) - - for site in range(m): - # Transition - p_neq = rho[site] / n - p_t = (1 - rho[site] + rho[site] / n) * V - recombinations = np.where(p_neq > p_t)[0] - p_t[recombinations] = p_neq - T[site] = recombinations - # Emission - p_e = np.zeros(n) + mu[site] - index = G[site] == h[site] - if h[site] == tskit.MISSING_DATA: - # Missing data is considered equal to everything - index[:] = True - p_e[index] = 1 - (len(alleles[site]) - 1) * mu[site] - V = p_t * p_e - # Normalise - max_index[site] = np.argmax(V) - # print(site, ":", V) - if V[max_index[site]] == 0: - assert mu[site] == 0 - raise ValueError( - "Trying to match non-existent allele with zero mutation rate" - ) - V /= V[max_index[site]] - - # Traceback - P = np.zeros(m, dtype=int) - site = m - 1 - P[site] = max_index[site] - while site > 0: - j = P[site] - if in_sorted(T[site], j): - j = max_index[site - 1] - P[site - 1] = j - site -= 1 - return P - - -def check_alleles(alleles, num_sites): +def check_alleles(alleles, m): """ Checks the specified allele list and returns a list of lists of alleles of length num_sites. - If alleles is a 1D list of strings, assume that this list is used for each site and return num_sites copies of this list. - Otherwise, raise a ValueError if alleles is not a list of length num_sites. """ if isinstance(alleles[0], str): - return [alleles for _ in range(num_sites)] - if len(alleles) != num_sites: + return [alleles for _ in range(m)], np.int8([len(alleles) for _ in range(m)]) + if len(alleles) != m: raise ValueError("Malformed alleles list") - return alleles + n_alleles = np.int8([(len(alleles_site)) for alleles_site in alleles]) + return alleles, n_alleles -def ls_forward_matrix(h, alleles, G, rho, mu): +def mirror_coordinates(ts): """ - Simple matrix based method for LS forward algorithm using numpy vectorisation. + Returns a copy of the specified tree sequence in which all + coordinates x are transformed into L - x. """ - assert rho[0] == 0 - m, n = G.shape - alleles = check_alleles(alleles, m) - F = np.zeros((m, n)) - S = np.zeros(m) - f = np.zeros(n) + 1 / n - p_e = np.zeros(n) - - for el in range(0, m): - p_t = f * (1 - rho[el]) + rho[el] / n - eq = G[el] == h[el] - if h[el] == tskit.MISSING_DATA: - # Missing data is equal to everything - eq[:] = True - p_e[:] = mu[el] - p_e[eq] = 1 - (len(alleles[el]) - 1) * mu[el] - f = p_t * p_e - S[el] = np.sum(f) - # TODO need to handle the 0 case. - assert S[el] > 0 - f /= S[el] - F[el] = f - return F, S - - -def forward_matrix_log_proba(F, S): - """ - Given the specified forward matrix and scaling factor array, return the - overall log probability of the input haplotype. - """ - return np.sum(np.log(S)) - np.log(np.sum(F[-1])) - - -def ls_forward_matrix_unscaled(h, alleles, G, rho, mu): - """ - Simple matrix based method for LS forward algorithm. - """ - assert rho[0] == 0 - m, n = G.shape - alleles = check_alleles(alleles, m) - F = np.zeros((m, n)) - f = np.zeros(n) + 1 / n - - for el in range(0, m): - s = np.sum(f) - for j in range(n): - p_t = f[j] * (1 - rho[el]) + s * rho[el] / n - p_e = mu[el] - if G[el, j] == h[el] or h[el] == tskit.MISSING_DATA: - p_e = 1 - (len(alleles[el]) - 1) * mu[el] - f[j] = p_t * p_e - F[el] = f - return F - - -# TODO change this to use the log_proba function below. 
-def ls_path_probability(h, path, G, rho, mu): - """ - Returns the probability of the specified path through the genotypes for the - specified haplotype. - """ - # Assuming num_alleles = 2 - assert rho[0] == 0 - m, n = G.shape - # TODO It's not entirely clear why we're starting with a proba of 1 / n for the - # model. This was done because it made it easier to compare with an existing - # HMM implementation. Need to figure this one out when writing up. - proba = 1 / n - for site in range(0, m): - pe = mu[site] - if h[site] == G[site, path[site]] or h[site] == tskit.MISSING_DATA: - pe = 1 - mu[site] - pt = rho[site] / n - if site == 0 or path[site] == path[site - 1]: - pt = 1 - rho[site] + rho[site] / n - proba *= pt * pe - return proba - - -def ls_path_log_probability(h, path, alleles, G, rho, mu): - """ - Returns the log probability of the specified path through the genotypes for the - specified haplotype. - """ - assert rho[0] == 0 - m, n = G.shape - alleles = check_alleles(alleles, m) - # TODO It's not entirely clear why we're starting with a proba of 1 / n for the - # model. This was done because it made it easier to compare with an existing - # HMM implementation. Need to figure this one out when writing up. - log_proba = np.log(1 / n) - for site in range(0, m): - if len(alleles[site]) > 1: - assert mu[site] <= 1 / (len(alleles[site]) - 1) - pe = mu[site] - if h[site] == G[site, path[site]] or h[site] == tskit.MISSING_DATA: - pe = 1 - (len(alleles[site]) - 1) * mu[site] - assert 0 <= pe <= 1 - pt = rho[site] / n - if site == 0 or path[site] == path[site - 1]: - pt = 1 - rho[site] + rho[site] / n - assert 0 <= pt <= 1 - log_proba += np.log(pt) + np.log(pe) - return log_proba - - -def ls_forward_tree(h, alleles, ts, rho, mu, precision=30, use_lib=True): - """ - Forward matrix computation based on a tree sequence. - """ - if use_lib: - acgt_alleles = tuple(alleles) == tskit.ALLELES_ACGT - ls_hmm = _tskit.LsHmm( - ts.ll_tree_sequence, - recombination_rate=rho, - mutation_rate=mu, - precision=precision, - acgt_alleles=acgt_alleles, - ) - cm = _tskit.CompressedMatrix(ts.ll_tree_sequence) - ls_hmm.forward_matrix(h, cm) - return cm - else: - fa = ForwardAlgorithm(ts, rho, mu, alleles, precision=precision) - return fa.run(h) - - -def ls_viterbi_tree(h, alleles, ts, rho, mu, precision=30, use_lib=True): - """ - Viterbi path computation based on a tree sequence. - """ - if use_lib: - acgt_alleles = tuple(alleles) == tskit.ALLELES_ACGT - ls_hmm = _tskit.LsHmm( - ts.ll_tree_sequence, - recombination_rate=rho, - mutation_rate=mu, - precision=precision, - acgt_alleles=acgt_alleles, - ) - vm = _tskit.ViterbiMatrix(ts.ll_tree_sequence) - ls_hmm.viterbi_matrix(h, vm) - return vm - else: - va = ViterbiAlgorithm(ts, rho, mu, alleles, precision=precision) - return va.run(h) + L = ts.sequence_length + tables = ts.dump_tables() + left = tables.edges.left + right = tables.edges.right + tables.edges.left = L - right + tables.edges.right = L - left + tables.sites.position = L - tables.sites.position # + 1 + # TODO migrations. + tables.sort() + return tables.tree_sequence() class ValueTransition: - """ - Simple struct holding value transition values. 
- """ + """Simple struct holding value transition values.""" def __init__(self, tree_node=-1, value=-1, value_index=-1): self.tree_node = tree_node @@ -353,7 +76,11 @@ def __init__(self, tree_node=-1, value=-1, value_index=-1): self.value_index = value_index def copy(self): - return ValueTransition(self.tree_node, self.value, self.value_index) + return ValueTransition( + self.tree_node, + self.value, + self.value_index, + ) def __repr__(self): return repr(self.__dict__) @@ -367,11 +94,12 @@ class LsHmmAlgorithm: Abstract superclass of Li and Stephens HMM algorithm. """ - def __init__(self, ts, rho, mu, alleles, precision=10): + def __init__( + self, ts, rho, mu, alleles, n_alleles, precision=10, scale_mutation=False + ): self.ts = ts self.mu = mu self.rho = rho - self.alleles = check_alleles(alleles, ts.num_sites) self.precision = precision # The array of ValueTransitions. self.T = [] @@ -386,6 +114,10 @@ def __init__(self, ts, rho, mu, alleles, precision=10): self.parent = np.zeros(self.ts.num_nodes, dtype=int) - 1 self.tree = tskit.Tree(self.ts) self.output = None + # Vector of the number of alleles at each site + self.n_alleles = n_alleles + self.alleles = alleles + self.scale_mutation_based_on_n_alleles = scale_mutation def check_integrity(self): M = [st.tree_node for st in self.T if st.tree_node != -1] @@ -422,10 +154,6 @@ def compute(u, parent_state): for j in range(num_values): value_count[j] += child[j] max_value_count = np.max(value_count) - # NOTE: we need to set the set to zero here because we actually - # visit some nodes more than once during the postorder traversal. - # This would seem to be wasteful, so we should revisit this when - # cleaning up the algorithm logic. optimal_set[u, :] = 0 optimal_set[u, value_count == max_value_count] = 1 @@ -566,9 +294,9 @@ def update_probabilities(self, site, haplotype_state): T = self.T alleles = self.alleles[site.id] allelic_state = self.allelic_state - # Set the allelic_state for this site. allelic_state[tree.root] = alleles.index(site.ancestral_state) + for mutation in site.mutations: u = mutation.node allelic_state[u] = alleles.index(mutation.derived_state) @@ -590,8 +318,7 @@ def update_probabilities(self, site, haplotype_state): v = tree.parent(v) assert v != -1 match = ( - haplotype_state == tskit.MISSING_DATA - or haplotype_state == allelic_state[v] + haplotype_state == MISSING or haplotype_state == allelic_state[v] ) st.value = self.compute_next_probability(site.id, st.value, match, u) @@ -600,31 +327,41 @@ def update_probabilities(self, site, haplotype_state): for mutation in site.mutations: allelic_state[mutation.node] = -1 - def process_site(self, site, haplotype_state): - # print(site.id, "num_transitions=", len(self.T)) - self.update_probabilities(site, haplotype_state) - # FIXME We don't want to call compress here. - # What we really want to do is just call compress after - # the values have been normalised and rounded. However, we can't - # compute the normalisation factor in the forwards algorithm without - # the N counts (number of samples directly below each value transition - # in T), and these are currently computed during compress. So to make - # things work for now we call compress before and put up with having - # a slightly less than optimally compressed output matrix. It might - # end up that this makes no difference and compressing the - # pre-rounded values is basically the same thing. 
- self.compress() - s = self.compute_normalisation_factor() - for st in self.T: - if st.tree_node != tskit.NULL: - st.value /= s - st.value = round(st.value, self.precision) - # *This* is where we want to compress (and can, for viterbi). - # self.compress() - self.output.store_site(site.id, s, [(st.tree_node, st.value) for st in self.T]) - - def run(self, h): + def process_site(self, site, haplotype_state, forwards=True): + if forwards: + # Forwards algorithm, or forwards pass in Viterbi + self.update_probabilities(site, haplotype_state) + self.compress() + s = self.compute_normalisation_factor() + for st in self.T: + if st.tree_node != tskit.NULL: + st.value /= s + st.value = round(st.value, self.precision) + self.output.store_site( + site.id, s, [(st.tree_node, st.value) for st in self.T] + ) + else: + # Backwards algorithm + self.output.store_site( + site.id, + self.output.normalisation_factor[site.id], + [(st.tree_node, st.value) for st in self.T], + ) + self.update_probabilities(site, haplotype_state) + self.compress() + b_last_sum = self.compute_normalisation_factor() + s = self.output.normalisation_factor[site.id] + for st in self.T: + if st.tree_node != tskit.NULL: + st.value = ( + self.rho[site.id] / self.ts.num_samples + ) * b_last_sum + (1 - self.rho[site.id]) * st.value + st.value /= s + st.value = round(st.value, self.precision) + + def run_forward(self, h): n = self.ts.num_samples + self.tree.clear() for u in self.ts.samples(): self.T_index[u] = len(self.T) self.T.append(ValueTransition(tree_node=u, value=1 / n)) @@ -634,6 +371,17 @@ def run(self, h): self.process_site(site, h[site.id]) return self.output + def run_backward(self, h): + self.tree.clear() + for u in self.ts.samples(): + self.T_index[u] = len(self.T) + self.T.append(ValueTransition(tree_node=u, value=1)) + while self.tree.next(): + self.update_tree() + for site in self.tree.sites(): + self.process_site(site, h[site.id], forwards=False) + return self.output + def compute_normalisation_factor(self): raise NotImplementedError() @@ -650,12 +398,16 @@ class CompressedMatrix: values are on the path). """ - def __init__(self, ts): + def __init__(self, ts, normalisation_factor=None): self.ts = ts self.num_sites = ts.num_sites self.num_samples = ts.num_samples self.value_transitions = [None for _ in range(self.num_sites)] - self.normalisation_factor = np.zeros(self.num_sites) + if normalisation_factor is None: + self.normalisation_factor = np.zeros(self.num_sites) + else: + self.normalisation_factor = normalisation_factor + assert len(self.normalisation_factor) == self.num_sites def store_site(self, site, normalisation_factor, value_transitions): self.normalisation_factor[site] = normalisation_factor @@ -688,39 +440,11 @@ def decode(self): class ForwardMatrix(CompressedMatrix): - """ - Class representing a compressed forward matrix. - """ - - -class ForwardAlgorithm(LsHmmAlgorithm): - """ - Runs the Li and Stephens forward algorithm. 
- """ - - def __init__(self, ts, rho, mu, alleles, precision=10): - super().__init__(ts, rho, mu, alleles, precision) - self.output = ForwardMatrix(ts) - - def compute_normalisation_factor(self): - s = 0 - for j, st in enumerate(self.T): - assert st.tree_node != tskit.NULL - assert self.N[j] > 0 - s += self.N[j] * st.value - return s + """Class representing a compressed forward matrix.""" - def compute_next_probability(self, site_id, p_last, is_match, node): - rho = self.rho[site_id] - mu = self.mu[site_id] - alleles = self.alleles[site_id] - n = self.ts.num_samples - p_t = p_last * (1 - rho) + rho / n - p_e = mu - if is_match: - p_e = 1 - (len(alleles) - 1) * mu - return p_t * p_e +class BackwardMatrix(CompressedMatrix): + """Class representing a compressed backward matrix.""" class ViterbiMatrix(CompressedMatrix): @@ -730,6 +454,8 @@ class ViterbiMatrix(CompressedMatrix): def __init__(self, ts): super().__init__(ts) + # Tuple containing the site, the node in the tree, and whether + # recombination is required self.recombination_required = [(-1, 0, False)] def add_recombination_required(self, site, node, required): @@ -801,13 +527,144 @@ def traceback(self): return match +class ForwardAlgorithm(LsHmmAlgorithm): + """Runs the Li and Stephens forward algorithm.""" + + def __init__( + self, ts, rho, mu, alleles, n_alleles, scale_mutation=False, precision=10 + ): + super().__init__( + ts, + rho, + mu, + alleles, + n_alleles, + precision=precision, + scale_mutation=scale_mutation, + ) + self.output = ForwardMatrix(ts) + + def compute_normalisation_factor(self): + s = 0 + for j, st in enumerate(self.T): + assert st.tree_node != tskit.NULL + assert self.N[j] > 0 + s += self.N[j] * st.value + return s + + def compute_next_probability( + self, site_id, p_last, is_match, node + ): # Note node only used in Viterbi + rho = self.rho[site_id] + mu = self.mu[site_id] + n = self.ts.num_samples + n_alleles = self.n_alleles[site_id] + + if self.scale_mutation_based_on_n_alleles: + if is_match: + # Scale mutation based on the number of alleles + # - so the mutation rate is the mutation rate to one of the + # alleles. The overall mutation rate is then + # (n_alleles - 1) * mutation_rate. + p_e = 1 - (n_alleles - 1) * mu + else: + p_e = mu - mu * (n_alleles == 1) + # Added boolean in case we're at an invariant site + else: + # No scaling based on the number of alleles + # - so the mutation rate is the mutation rate to anything. + # This means that we must rescale the mutation rate to a different + # allele, by the number of alleles. 
+ if n_alleles == 1: # In case we're at an invariant site + if is_match: + p_e = 1 + else: + p_e = 0 + else: + if is_match: + p_e = 1 - mu + else: + p_e = mu / (n_alleles - 1) + + p_t = p_last * (1 - rho) + rho / n + return p_t * p_e + + +class BackwardAlgorithm(LsHmmAlgorithm): + """Runs the Li and Stephens backward algorithm.""" + + def __init__( + self, + ts, + rho, + mu, + alleles, + n_alleles, + normalisation_factor, + scale_mutation=False, + precision=10, + ): + super().__init__( + ts, + rho, + mu, + alleles, + n_alleles, + precision=precision, + scale_mutation=scale_mutation, + ) + self.output = BackwardMatrix(ts, normalisation_factor) + + def compute_normalisation_factor(self): + s = 0 + for j, st in enumerate(self.T): + assert st.tree_node != tskit.NULL + assert self.N[j] > 0 + s += self.N[j] * st.value + return s + + def compute_next_probability( + self, site_id, p_next, is_match, node + ): # Note node only used in Viterbi + mu = self.mu[site_id] + n_alleles = self.n_alleles[site_id] + + if self.scale_mutation_based_on_n_alleles: + if is_match: + p_e = 1 - (n_alleles - 1) * mu + else: + p_e = mu - mu * (n_alleles == 1) + else: + if n_alleles == 1: + if is_match: + p_e = 1 + else: + p_e = 0 + else: + if is_match: + p_e = 1 - mu + else: + p_e = mu / (n_alleles - 1) + return p_next * p_e + + class ViterbiAlgorithm(LsHmmAlgorithm): """ Runs the Li and Stephens Viterbi algorithm. """ - def __init__(self, ts, rho, mu, alleles, precision=10): - super().__init__(ts, rho, mu, alleles, precision) + def __init__( + self, ts, rho, mu, alleles, n_alleles, scale_mutation=False, precision=10 + ): + super().__init__( + ts, + rho, + mu, + alleles, + n_alleles, + precision=precision, + scale_mutation=scale_mutation, + ) self.output = ViterbiMatrix(ts) def compute_normalisation_factor(self): @@ -825,8 +682,8 @@ def compute_normalisation_factor(self): def compute_next_probability(self, site_id, p_last, is_match, node): rho = self.rho[site_id] mu = self.mu[site_id] - alleles = self.alleles[site_id] n = self.ts.num_samples + n_alleles = self.n_alleles[site_id] p_no_recomb = p_last * (1 - rho + rho / n) p_recomb = rho / n @@ -837,474 +694,427 @@ def compute_next_probability(self, site_id, p_last, is_match, node): p_t = p_recomb recombination_required = True self.output.add_recombination_required(site_id, node, recombination_required) - p_e = mu - if is_match: - p_e = 1 - (len(alleles) - 1) * mu - return p_t * p_e + if self.scale_mutation_based_on_n_alleles: + if is_match: + # Scale mutation based on the number of alleles + # - so the mutation rate is the mutation rate to one of the + # alleles. The overall mutation rate is then + # (n_alleles - 1) * mutation_rate. + p_e = 1 - (n_alleles - 1) * mu + else: + p_e = mu - mu * (n_alleles == 1) + # Added boolean in case we're at an invariant site + else: + # No scaling based on the number of alleles + # - so the mutation rate is the mutation rate to anything. + # This means that we must rescale the mutation rate to a different + # allele, by the number of alleles. 
+ if n_alleles == 1: # In case we're at an invariant site + if is_match: + p_e = 1 + else: + p_e = 0 + else: + if is_match: + p_e = 1 - mu + else: + p_e = mu / (n_alleles - 1) -################################################################ -# Tests -################################################################ + return p_t * p_e -class LiStephensBase: +def ls_forward_tree( + h, ts, rho, mu, precision=30, alleles=None, scale_mutation_based_on_n_alleles=False +): + if alleles is None: + n_alleles = np.int8( + [ + len(np.unique(np.append(ts.genotype_matrix()[j, :], h[j]))) + for j in range(ts.num_sites) + ] + ) + alleles = tskit.ALLELES_ACGT + if len(set(alleles).intersection(next(ts.variants()).alleles)) == 0: + alleles = tskit.ALLELES_01 + if len(set(alleles).intersection(next(ts.variants()).alleles)) == 0: + raise ValueError( + """Alleles list could not be identified. + Please pass a list of lists of alleles of length m, + or a list of alleles (e.g. tskit.ALLELES_ACGT)""" + ) + alleles = [alleles for _ in range(ts.num_sites)] + else: + alleles, n_alleles = check_alleles(alleles, ts.num_sites) + + """Forward matrix computation based on a tree sequence.""" + fa = ForwardAlgorithm( + ts, + rho, + mu, + alleles, + n_alleles, + precision=precision, + scale_mutation=scale_mutation_based_on_n_alleles, + ) + return fa.run_forward(h) + + +def ls_backward_tree( + h, ts_mirror, rho, mu, normalisation_factor, precision=30, alleles=None +): + if alleles is None: + n_alleles = np.int8( + [ + len(np.unique(np.append(ts_mirror.genotype_matrix()[j, :], h[j]))) + for j in range(ts_mirror.num_sites) + ] + ) + alleles = tskit.ALLELES_ACGT + if len(set(alleles).intersection(next(ts_mirror.variants()).alleles)) == 0: + alleles = tskit.ALLELES_01 + if len(set(alleles).intersection(next(ts_mirror.variants()).alleles)) == 0: + raise ValueError( + """Alleles list could not be identified. + Please pass a list of lists of alleles of length m, + or a list of alleles (e.g. tskit.ALLELES_ACGT)""" + ) + alleles = [alleles for _ in range(ts_mirror.num_sites)] + else: + alleles, n_alleles = check_alleles(alleles, ts_mirror.num_sites) + + """Backward matrix computation based on a tree sequence.""" + ba = BackwardAlgorithm( + ts_mirror, + rho, + mu, + alleles, + n_alleles, + normalisation_factor, + precision=precision, + ) + return ba.run_backward(h) + + +def ls_viterbi_tree( + h, ts, rho, mu, precision=30, alleles=None, scale_mutation_based_on_n_alleles=False +): + if alleles is None: + n_alleles = np.int8( + [ + len(np.unique(np.append(ts.genotype_matrix()[j, :], h[j]))) + for j in range(ts.num_sites) + ] + ) + alleles = tskit.ALLELES_ACGT + if len(set(alleles).intersection(next(ts.variants()).alleles)) == 0: + alleles = tskit.ALLELES_01 + if len(set(alleles).intersection(next(ts.variants()).alleles)) == 0: + raise ValueError( + """Alleles list could not be identified. + Please pass a list of lists of alleles of length m, + or a list of alleles (e.g. tskit.ALLELES_ACGT)""" + ) + alleles = [alleles for _ in range(ts.num_sites)] + else: + alleles, n_alleles = check_alleles(alleles, ts.num_sites) """ - Superclass of Li and Stephens tests. + Viterbi path computation based on a tree sequence. """ - - def assertCompressedMatricesEqual(self, cm1, cm2): - """ - Checks that the specified compressed matrices contain the same data. 
- """ - A1 = cm1.decode() - A2 = cm2.decode() - assert np.allclose(A1, A2) - assert A1.shape == A2.shape - assert cm1.num_sites == cm2.num_sites - nf1 = cm1.normalisation_factor - nf2 = cm1.normalisation_factor - assert np.allclose(nf1, nf2) - assert nf1.shape == nf2.shape - # It seems that we can't rely on the number of transitions in the two - # implementations being equal, which seems odd given that we should - # be doing things identically. Still, once the decoded matrices are the - # same then it seems highly likely to be correct. - - # if not np.array_equal(cm1.num_transitions, cm2.num_transitions): - # print() - # print(cm1.num_transitions) - # print(cm2.num_transitions) - # self.assertTrue(np.array_equal(cm1.num_transitions, cm2.num_transitions)) - # for j in range(cm1.num_sites): - # s1 = dict(cm1.get_site(j)) - # s2 = dict(cm2.get_site(j)) - # self.assertEqual(set(s1.keys()), set(s2.keys())) - # for key in s1.keys(): - # self.assertAlmostEqual(s1[key], s2[key]) - - def example_haplotypes(self, ts, alleles, num_random=10, seed=2): - rng = np.random.RandomState(seed) - H = ts.genotype_matrix(alleles=alleles).T - haplotypes = [H[0], H[-1]] - for _ in range(num_random): - # Choose a random path through H - p = rng.randint(0, ts.num_samples, ts.num_sites) - h = H[p, np.arange(ts.num_sites)] - haplotypes.append(h) - h = H[0].copy() - h[-1] = tskit.MISSING_DATA - haplotypes.append(h) - h = H[0].copy() - h[ts.num_sites // 2] = tskit.MISSING_DATA - haplotypes.append(h) - # All missing is OK tool - h = H[0].copy() - h[:] = tskit.MISSING_DATA - haplotypes.append(h) - return haplotypes - - def example_parameters(self, ts, alleles, seed=1): - """ - Returns an iterator over combinations of haplotype, recombination and mutation - rates. - """ - rng = np.random.RandomState(seed) - haplotypes = self.example_haplotypes(ts, alleles, seed=seed) - - # This is the exact matching limit. 
- rho = np.zeros(ts.num_sites) + 0.01 - mu = np.zeros(ts.num_sites) - rho[0] = 0 - for h in haplotypes: - yield h, rho, mu + va = ViterbiAlgorithm( + ts, + rho, + mu, + alleles, + n_alleles, + precision=precision, + scale_mutation=scale_mutation_based_on_n_alleles, + ) + return va.run_forward(h) + + +class LSBase: + """Superclass of Li and Stephens tests.""" + + def example_haplotypes(self, ts): + + H = ts.genotype_matrix() + s = H[:, 0].reshape(1, H.shape[0]) + H = H[:, 1:] + + haplotypes = [ + s, + H[:, -1].reshape(1, H.shape[0]), + ] + s_tmp = s.copy() + s_tmp[0, -1] = MISSING + haplotypes.append(s_tmp) + s_tmp = s.copy() + s_tmp[0, ts.num_sites // 2] = MISSING + haplotypes.append(s_tmp) + s_tmp = s.copy() + s_tmp[0, :] = MISSING + haplotypes.append(s_tmp) + + return H, haplotypes + + def example_parameters_haplotypes(self, ts, seed=42): + """Returns an iterator over combinations of haplotype, + recombination and mutation rates.""" + np.random.seed(seed) + H, haplotypes = self.example_haplotypes(ts) + n = H.shape[1] + m = ts.get_num_sites() # Here we have equal mutation and recombination - rho = np.zeros(ts.num_sites) + 0.01 - mu = np.zeros(ts.num_sites) + 0.01 - rho[0] = 0 - for h in haplotypes: - yield h, rho, mu + r = np.zeros(m) + 0.01 + mu = np.zeros(m) + 0.01 + r[0] = 0 + + for s in haplotypes: + yield n, H, s, r, mu # Mixture of random and extremes - rhos = [ - np.zeros(ts.num_sites) + 0.999, - np.zeros(ts.num_sites) + 1e-6, - rng.uniform(0, 1, ts.num_sites), - ] - # mu can't be more than 1 / 3 if we have 4 alleles - mus = [ - np.zeros(ts.num_sites) + 0.33, - np.zeros(ts.num_sites) + 1e-6, - rng.uniform(0, 0.33, ts.num_sites), - ] - for h, rho, mu in itertools.product(haplotypes, rhos, mus): - rho[0] = 0 - yield h, rho, mu + rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)] + mus = [np.zeros(m) + 0.33, np.zeros(m) + 1e-6, np.random.rand(m) * 0.33] + + for s, r, mu in itertools.product(haplotypes, rs, mus): + r[0] = 0 + yield n, H, s, r, mu def assertAllClose(self, A, B): - assert np.allclose(A, B) + """Assert that all entries of two matrices are 'close'""" + assert np.allclose(A, B, rtol=1e-5, atol=1e-8) + + # Define a bunch of very small tree-sequences for testing a collection + # of parameters on + def test_simple_n_10_no_recombination(self): + ts = msprime.simulate( + 10, recombination_rate=0, mutation_rate=0.5, random_seed=42 + ) + assert ts.num_sites > 3 + self.verify(ts) - def test_simple_n_4_no_recombination(self): - ts = msprime.simulate(4, recombination_rate=0, mutation_rate=0.5, random_seed=1) + def test_simple_n_10_no_recombination_high_mut(self): + ts = msprime.simulate(10, recombination_rate=0, mutation_rate=3, random_seed=42) assert ts.num_sites > 3 self.verify(ts) - def test_simple_n_3(self): - ts = msprime.simulate(3, recombination_rate=2, mutation_rate=7, random_seed=2) - assert ts.num_sites > 5 + def test_simple_n_10_no_recombination_higher_mut(self): + ts = msprime.simulate(20, recombination_rate=0, mutation_rate=3, random_seed=42) + assert ts.num_sites > 3 self.verify(ts) - def test_simple_n_7(self): - ts = msprime.simulate(7, recombination_rate=2, mutation_rate=5, random_seed=2) + def test_simple_n_6(self): + ts = msprime.simulate(6, recombination_rate=2, mutation_rate=7, random_seed=42) assert ts.num_sites > 5 self.verify(ts) - def test_simple_n_8_high_recombination(self): - ts = msprime.simulate(8, recombination_rate=20, mutation_rate=5, random_seed=2) - assert ts.num_trees > 15 + def test_simple_n_8(self): + ts = msprime.simulate(8, 
recombination_rate=2, mutation_rate=5, random_seed=42) assert ts.num_sites > 5 self.verify(ts) - def test_simple_n_15(self): - ts = msprime.simulate(15, recombination_rate=2, mutation_rate=5, random_seed=2) + def test_simple_n_8_high_recombination(self): + ts = msprime.simulate(8, recombination_rate=20, mutation_rate=5, random_seed=42) + assert ts.num_trees > 15 assert ts.num_sites > 5 self.verify(ts) - def test_jukes_cantor_n_3(self): - ts = msprime.simulate(3, mutation_rate=2, random_seed=2) - ts = tsutil.jukes_cantor(ts, num_sites=10, mu=10, seed=4) - self.verify(ts, tskit.ALLELES_ACGT) - - def test_jukes_cantor_n_8_high_recombination(self): - ts = msprime.simulate(8, recombination_rate=20, random_seed=2) - ts = tsutil.jukes_cantor(ts, num_sites=20, mu=5, seed=4) - self.verify(ts, tskit.ALLELES_ACGT) - - def test_jukes_cantor_n_15(self): - ts = msprime.simulate(15, mutation_rate=2, random_seed=2) - ts = tsutil.jukes_cantor(ts, num_sites=10, mu=0.1, seed=10) - self.verify(ts, tskit.ALLELES_ACGT) - - def test_jukes_cantor_balanced_ternary(self): - ts = tskit.Tree.generate_balanced(27, arity=3).tree_sequence - ts = tsutil.jukes_cantor(ts, num_sites=10, mu=0.1, seed=10) - self.verify(ts, tskit.ALLELES_ACGT) - - @pytest.mark.skip(reason="Not supporting internal samples yet") - def test_ancestors_n_3(self): - ts = msprime.simulate(3, recombination_rate=2, mutation_rate=7, random_seed=2) + def test_simple_n_16(self): + ts = msprime.simulate(16, recombination_rate=2, mutation_rate=5, random_seed=42) assert ts.num_sites > 5 - tables = ts.dump_tables() - print(tables.nodes) - tables.nodes.flags = np.ones_like(tables.nodes.flags) - print(tables.nodes) - ts = tables.tree_sequence() self.verify(ts) + # # Define a bunch of very small tree-sequences for testing a collection + # # of parameters on + # def test_simple_n_10_no_recombination_blah(self): + # ts = msprime.sim_ancestry( + # samples=10, + # recombination_rate=0, + # random_seed=42, + # sequence_length=10, + # population_size=10000, + # ) + # ts = msprime.sim_mutations(ts, rate=1e-5, random_seed=42) + # assert ts.num_sites > 3 + # self.verify(ts) + + # def test_simple_n_6_blah(self): + # ts = msprime.sim_ancestry( + # samples=6, + # recombination_rate=1e-4, + # random_seed=42, + # sequence_length=40, + # population_size=10000, + # ) + # ts = msprime.sim_mutations(ts, rate=1e-3, random_seed=42) + # assert ts.num_sites > 5 + # self.verify(ts) + + # def test_simple_n_8_blah(self): + # ts = msprime.sim_ancestry( + # samples=8, + # recombination_rate=1e-4, + # random_seed=42, + # sequence_length=20, + # population_size=10000, + # ) + # ts = msprime.sim_mutations(ts, rate=1e-4, random_seed=42) + # assert ts.num_sites > 5 + # assert ts.num_trees > 15 + # self.verify(ts) + + # def test_simple_n_16_blah(self): + # ts = msprime.sim_ancestry( + # samples=16, + # recombination_rate=1e-2, + # random_seed=42, + # sequence_length=20, + # population_size=10000, + # ) + # ts = msprime.sim_mutations(ts, rate=1e-4, random_seed=42) + # assert ts.num_sites > 5 + # self.verify(ts) + + def verify(self, ts): + raise NotImplementedError() -@pytest.mark.slow -class ForwardAlgorithmBase(LiStephensBase): - """ - Base for forward algorithm tests. - """ +class FBAlgorithmBase(LSBase): + """Base for forwards backwards algorithm tests.""" -class TestNumpyMatrixMethod(ForwardAlgorithmBase): - """ - Tests that we compute the same values from the numpy matrix method as - the naive algorithm. 
- """ - def verify(self, ts, alleles=tskit.ALLELES_01): - G = ts.genotype_matrix(alleles=alleles) - for h, rho, mu in self.example_parameters(ts, alleles): - F1, S1 = ls_forward_matrix(h, alleles, G, rho, mu) - F2, S2 = ls_forward_matrix_naive(h, alleles, G, rho, mu) - self.assertAllClose(F1, F2) - self.assertAllClose(S1, S2) +class VitAlgorithmBase(LSBase): + """Base for viterbi algoritm tests.""" -class ViterbiAlgorithmBase(LiStephensBase): - """ - Base for viterbi algoritm tests. - """ +class TestMirroringHap(FBAlgorithmBase): + """Tests that mirroring the tree sequence and running forwards and backwards + algorithms gives the same log-likelihood of observing the data.""" + def verify(self, ts): + for n, H, s, r, mu in self.example_parameters_haplotypes(ts): + # Note, need to remove the first sample from the ts, and ensure that + # invariant sites aren't removed. + ts_check = ts.simplify(range(1, n + 1), filter_sites=False) + cm = ls_forward_tree(s[0, :], ts_check, r, mu) + ll_tree = np.sum(np.log10(cm.normalisation_factor)) -class TestExactMatchViterbi(ViterbiAlgorithmBase): - def verify(self, ts, alleles=tskit.ALLELES_01): - G = ts.genotype_matrix(alleles=alleles) - H = G.T - # print(H) - rho = np.zeros(ts.num_sites) + 0.1 - mu = np.zeros(ts.num_sites) - rho[0] = 0 - for h in H: - p1 = ls_viterbi_naive(h, alleles, G, rho, mu) - p2 = ls_viterbi_vectorised(h, alleles, G, rho, mu) - cm1 = ls_viterbi_tree(h, alleles, ts, rho, mu, use_lib=True) - p3 = cm1.traceback() - cm2 = ls_viterbi_tree(h, alleles, ts, rho, mu, use_lib=False) - p4 = cm1.traceback() - self.assertCompressedMatricesEqual(cm1, cm2) - - assert len(np.unique(p1)) == 1 - assert len(np.unique(p2)) == 1 - assert len(np.unique(p3)) == 1 - assert len(np.unique(p4)) == 1 - m1 = H[p1, np.arange(H.shape[1])] - assert np.array_equal(m1, h) - m2 = H[p2, np.arange(H.shape[1])] - assert np.array_equal(m2, h) - m3 = H[p3, np.arange(H.shape[1])] - assert np.array_equal(m3, h) - m4 = H[p3, np.arange(H.shape[1])] - assert np.array_equal(m4, h) - - -@pytest.mark.slow -class TestGeneralViterbi(ViterbiAlgorithmBase, unittest.TestCase): - def verify(self, ts, alleles=tskit.ALLELES_01): - # np.set_printoptions(linewidth=20000) - # np.set_printoptions(threshold=20000000) - G = ts.genotype_matrix(alleles=alleles) - # m, n = G.shape - for h, rho, mu in self.example_parameters(ts, alleles): - # print("h = ", h) - # print("rho=", rho) - # print("mu = ", mu) - p1 = ls_viterbi_vectorised(h, alleles, G, rho, mu) - p2 = ls_viterbi_naive(h, alleles, G, rho, mu) - cm1 = ls_viterbi_tree(h, alleles, ts, rho, mu, use_lib=True) - p3 = cm1.traceback() - cm2 = ls_viterbi_tree(h, alleles, ts, rho, mu, use_lib=False) - p4 = cm1.traceback() - self.assertCompressedMatricesEqual(cm1, cm2) - # print() - # m1 = H[p1, np.arange(m)] - # m2 = H[p2, np.arange(m)] - # m3 = H[p3, np.arange(m)] - # count = np.unique(p1).shape[0] - # print() - # print("\tp1 = ", p1) - # print("\tp2 = ", p2) - # print("\tp3 = ", p3) - # print("\tm1 = ", m1) - # print("\tm2 = ", m2) - # print("\t h = ", h) - proba1 = ls_path_log_probability(h, p1, alleles, G, rho, mu) - proba2 = ls_path_log_probability(h, p2, alleles, G, rho, mu) - proba3 = ls_path_log_probability(h, p3, alleles, G, rho, mu) - proba4 = ls_path_log_probability(h, p4, alleles, G, rho, mu) - # print("\t P = ", proba1, proba2) - self.assertAlmostEqual(proba1, proba2, places=6) - self.assertAlmostEqual(proba1, proba3, places=6) - self.assertAlmostEqual(proba1, proba4, places=6) - - -class TestMissingHaplotypes(LiStephensBase): - def 
-    def verify(self, ts, alleles=tskit.ALLELES_01):
-        G = ts.genotype_matrix(alleles=alleles)
-        H = G.T
-
-        rho = np.zeros(ts.num_sites) + 0.1
-        rho[0] = 0
-        mu = np.zeros(ts.num_sites) + 0.001
-
-        # When everything is missing data we should have no recombinations.
-        h = H[0].copy()
-        h[:] = tskit.MISSING_DATA
-        path = ls_viterbi_vectorised(h, alleles, G, rho, mu)
-        assert np.all(path == 0)
-        cm = ls_viterbi_tree(h, alleles, ts, rho, mu, use_lib=True)
-        # For the tree base algorithm it's not simple which particular sample
-        # gets chosen.
-        path = cm.traceback()
-        assert len(set(path)) == 1
-
-        # TODO Not clear what else we can check about missing data.
-
-
-class TestForwardMatrixScaling(ForwardAlgorithmBase, unittest.TestCase):
-    """
-    Tests that we get the correct values from scaling version of the matrix
-    algorithm works correctly.
-    """
+            ts_check_mirror = mirror_coordinates(ts_check)
+            r_flip = np.insert(np.flip(r)[:-1], 0, 0)
+            cm_mirror = ls_forward_tree(
+                np.flip(s[0, :]), ts_check_mirror, r_flip, np.flip(mu)
+            )
+            ll_mirror_tree = np.sum(np.log10(cm_mirror.normalisation_factor))
+            self.assertAllClose(ll_tree, ll_mirror_tree)
+
+            # Ensure that the decoded matrices are the same
+            F_mirror_matrix, c, ll = ls.forwards(
+                np.flip(H, axis=0),
+                np.flip(s, axis=1),
+                r_flip,
+                mutation_rate=np.flip(mu),
+                scale_mutation_based_on_n_alleles=False,
+            )

-    def verify(self, ts, alleles=tskit.ALLELES_01):
-        G = ts.genotype_matrix(alleles=alleles)
-        computed_log_proba = False
-        for h, rho, mu in self.example_parameters(ts, alleles):
-            F_unscaled = ls_forward_matrix_unscaled(h, alleles, G, rho, mu)
-            F, S = ls_forward_matrix(h, alleles, G, rho, mu)
-            column = np.atleast_2d(np.cumprod(S)).T
-            F_scaled = F * column
-            self.assertAllClose(F_scaled, F_unscaled)
-            log_proba1 = forward_matrix_log_proba(F, S)
-            psum = np.sum(F_unscaled[-1])
-            # If the computed probability is close to zero, there's no point in
-            # computing.
-            if psum > 1e-20:
-                computed_log_proba = True
-                log_proba2 = np.log(psum)
-                self.assertAlmostEqual(log_proba1, log_proba2)
-        assert computed_log_proba
-
-
-class TestForwardTree(ForwardAlgorithmBase):
-    """
-    Tests that the tree algorithm computes the same forward matrix as the
-    simple method.
-    """
+            self.assertAllClose(F_mirror_matrix, cm_mirror.decode())
+            self.assertAllClose(ll, ll_tree)

-    def verify(self, ts, alleles=tskit.ALLELES_01):
-        G = ts.genotype_matrix(alleles=alleles)
-        for h, rho, mu in self.example_parameters(ts, alleles):
-            F, S = ls_forward_matrix(h, alleles, G, rho, mu)
-            cm1 = ls_forward_tree(h, alleles, ts, rho, mu, use_lib=True)
-            cm2 = ls_forward_tree(h, alleles, ts, rho, mu, use_lib=False)
-            self.assertCompressedMatricesEqual(cm1, cm2)
-            Ft = cm1.decode()
-            self.assertAllClose(S, cm1.normalisation_factor)
-            self.assertAllClose(F, Ft)
+class TestForwardHapTree(FBAlgorithmBase):
+    """Tests that the tree algorithm computes the same forward matrix as the
+    simple method."""

-class TestAllPaths(unittest.TestCase):
-    """
-    Tests that we compute the correct forward probablities if we sum over all
-    possible paths through the genotype matrix.
-    """
+    def verify(self, ts):
+        for n, H, s, r, mu in self.example_parameters_haplotypes(ts):
+            for scale_mutation in [False, True]:
+                F, c, ll = ls.forwards(
+                    H,
+                    s,
+                    r,
+                    mutation_rate=mu,
+                    scale_mutation_based_on_n_alleles=scale_mutation,
+                )
+                # Note, need to remove the first sample from the ts, and ensure
+                # that invariant sites aren't removed.
+                ts_check = ts.simplify(range(1, n + 1), filter_sites=False)
+                cm = ls_forward_tree(
+                    s[0, :],
+                    ts_check,
+                    r,
+                    mu,
+                    scale_mutation_based_on_n_alleles=scale_mutation,
+                )
+                self.assertAllClose(cm.decode(), F)
+                ll_tree = np.sum(np.log10(cm.normalisation_factor))
+                self.assertAllClose(ll, ll_tree)

-    def verify(self, G, h):
-        m, n = G.shape
-        rho = np.zeros(m) + 0.1
-        mu = np.zeros(m) + 0.01
-        rho[0] = 0
-        proba = 0
-        for path in itertools.product(range(n), repeat=m):
-            proba += ls_path_probability(h, path, G, rho, mu)
-
-        alleles = [["0", "1"] for _ in range(m)]
-        F = ls_forward_matrix_unscaled(h, alleles, G, rho, mu)
-        forward_proba = np.sum(F[-1])
-        self.assertAlmostEqual(proba, forward_proba)
-
-    def test_n3_m4(self):
-        G = np.array(
-            [
-                # fmt: off
-                [1, 0, 0],
-                [0, 0, 1],
-                [1, 0, 1],
-                [0, 1, 1],
-                # fmt: on
-            ]
-        )
-        self.verify(G, [0, 0, 0, 0])
-        self.verify(G, [1, 1, 1, 1])
-        self.verify(G, [1, 1, 0, 0])
-
-    def test_n4_m5(self):
-        G = np.array(
-            [
-                # fmt: off
-                [1, 0, 0, 0],
-                [0, 0, 1, 1],
-                [1, 0, 1, 1],
-                [0, 1, 1, 0],
-                # fmt: on
-            ]
-        )
-        self.verify(G, [0, 0, 0, 0, 0])
-        self.verify(G, [1, 1, 1, 1, 1])
-        self.verify(G, [1, 1, 0, 0, 0])
+class TestForwardBackwardTree(FBAlgorithmBase):
+    """Tests that the tree algorithm computes the same forward and backward
+    matrices as the simple method."""

-    def test_n5_m5(self):
-        G = np.zeros((5, 5), dtype=int)
-        np.fill_diagonal(G, 1)
-        self.verify(G, [0, 0, 0, 0, 0])
-        self.verify(G, [1, 1, 1, 1, 1])
-        self.verify(G, [1, 1, 0, 0, 0])
+    def verify(self, ts):
+        for n, H, s, r, mu in self.example_parameters_haplotypes(ts):
+            F, c, ll = ls.forwards(
+                H, s, r, mutation_rate=mu, scale_mutation_based_on_n_alleles=False
+            )
+            B = ls.backwards(
+                H,
+                s,
+                c,
+                r,
+                mutation_rate=mu,
+                scale_mutation_based_on_n_alleles=False,
+            )
+            # Note, need to remove the first sample from the ts, and ensure that
+            # invariant sites aren't removed.
+            ts_check = ts.simplify(range(1, n + 1), filter_sites=False)
+            c_f = ls_forward_tree(s[0, :], ts_check, r, mu)
+            ll_tree = np.sum(np.log10(c_f.normalisation_factor))
+
+            ts_check_mirror = mirror_coordinates(ts_check)
+            r_flip = np.flip(r)
+            c_b = ls_backward_tree(
+                np.flip(s[0, :]),
+                ts_check_mirror,
+                r_flip,
+                np.flip(mu),
+                np.flip(c_f.normalisation_factor),
+            )
+            B_tree = np.flip(c_b.decode(), axis=0)
+            F_tree = c_f.decode()

-class TestBasicViterbi:
-    """
-    Very simple tests of the Viterbi algorithm.
- """ + self.assertAllClose(B, B_tree) + self.assertAllClose(F, F_tree) + self.assertAllClose(ll, ll_tree) - def verify_exact_match(self, G, h, path): - m, n = G.shape - rho = np.zeros(m) + 1e-9 - mu = np.zeros(m) # Set mu to zero exact match - rho[0] = 0 - alleles = [["0", "1"] for _ in range(m)] - path1 = ls_viterbi_naive(h, alleles, G, rho, mu) - path2 = ls_viterbi_vectorised(h, alleles, G, rho, mu) - assert list(path1) == path - assert list(path2) == path - - def test_n2_m6_exact(self): - G = np.array( - [ - # fmt: off - [1, 0], - [1, 0], - [1, 0], - [0, 1], - [0, 1], - [0, 1], - # fmt: on - ] - ) - self.verify_exact_match(G, [1, 1, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1]) - self.verify_exact_match(G, [0, 0, 0, 0, 0, 0], [1, 1, 1, 0, 0, 0]) - self.verify_exact_match(G, [0, 0, 0, 1, 1, 1], [1, 1, 1, 1, 1, 1]) - self.verify_exact_match(G, [0, 0, 0, 1, 1, 0], [1, 1, 1, 1, 1, 0]) - self.verify_exact_match(G, [0, 0, 0, 0, 1, 0], [1, 1, 1, 0, 1, 0]) - - def test_n3_m6_exact(self): - G = np.array( - [ - # fmt: off - [1, 0, 1], - [1, 0, 0], - [1, 0, 1], - [0, 1, 0], - [0, 1, 1], - [0, 1, 0], - # fmt: on - ] - ) - self.verify_exact_match(G, [1, 1, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1]) - self.verify_exact_match(G, [0, 0, 0, 0, 0, 0], [1, 1, 1, 0, 0, 0]) - self.verify_exact_match(G, [0, 0, 0, 1, 1, 1], [1, 1, 1, 1, 1, 1]) - self.verify_exact_match(G, [1, 0, 1, 0, 1, 0], [2, 2, 2, 2, 2, 2]) - def test_n3_m6(self): - G = np.array( - [ - # fmt: off - [1, 0, 1], - [1, 0, 0], - [1, 0, 1], - [0, 1, 0], - [0, 1, 1], - [0, 1, 0], - # fmt: on - ] - ) +class TestTreeViterbiHap(VitAlgorithmBase): + """Test that we have the same log-likelihood between tree and matrix + implementations""" - m, n = G.shape - rho = np.zeros(m) + 1e-2 - mu = np.zeros(m) - rho[0] = 0 - alleles = [["0", "1"] for _ in range(m)] - h = np.ones(m, dtype=int) - path1 = ls_viterbi_naive(h, alleles, G, rho, mu) - - # Add in mutation at a very low rate. - mu[:] = 1e-8 - path2 = ls_viterbi_naive(h, alleles, G, rho, mu) - path3 = ls_viterbi_vectorised(h, alleles, G, rho, mu) - assert np.array_equal(path1, path2) - assert np.array_equal(path2, path3) + def verify(self, ts): + for n, H, s, r, mu in self.example_parameters_haplotypes(ts): + path, ll = ls.viterbi( + H, s, r, mutation_rate=mu, scale_mutation_based_on_n_alleles=False + ) + ts_check = ts.simplify(range(1, n + 1), filter_sites=False) + cm = ls_viterbi_tree(s[0, :], ts_check, r, mu) + ll_tree = np.sum(np.log10(cm.normalisation_factor)) + self.assertAllClose(ll, ll_tree) + + # Now, need to ensure that the likelihood of the preferred path is + # the same as ll_tree (and ll). + path_tree = cm.traceback() + ll_check = ls.path_ll( + H, + s, + path_tree, + r, + mutation_rate=mu, + scale_mutation_based_on_n_alleles=False, + ) + self.assertAllClose(ll, ll_check)