diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c index 14decb18ff..214d473d41 100644 --- a/c/tests/test_stats.c +++ b/c/tests/test_stats.c @@ -2036,22 +2036,28 @@ static void test_paper_ex_two_site(void) { tsk_treeseq_t ts; - double *result; - tsk_size_t s, result_size; + double result[27]; + tsk_size_t s, result_size, num_sample_sets; int ret; - double truth_one_set[6] = { 1, 0.1111111111111111, 0.1111111111111111, 1, 1, 1 }; - double truth_two_sets[12] = { 1, 1, 0.1111111111111111, 0.1111111111111111, - 0.1111111111111111, 0.1111111111111111, 1, 1, 1, 1, 1, 1 }; - double truth_three_sets[18] = { 1, 1, 0, 0.1111111111111111, 0.1111111111111111, 0, - 0.1111111111111111, 0.1111111111111111, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1 }; + double truth_one_set[9] = { 1, 0.1111111111111111, 0.1111111111111111, + 0.1111111111111111, 1, 1, 0.1111111111111111, 1, 1 }; + double truth_two_sets[18] = { 1, 1, 0.1111111111111111, 0.1111111111111111, + 0.1111111111111111, 0.1111111111111111, 0.1111111111111111, 0.1111111111111111, + 1, 1, 1, 1, 0.1111111111111111, 0.1111111111111111, 1, 1, 1, 1 }; + double truth_three_sets[27] + = { 1, 1, 0, 0.1111111111111111, 0.1111111111111111, 0, 0.1111111111111111, + 0.1111111111111111, 0, 0.1111111111111111, 0.1111111111111111, 0, 1, 1, 1, + 1, 1, 1, 0.1111111111111111, 0.1111111111111111, 0, 1, 1, 1, 1, 1, 1 }; tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, paper_ex_mutations, paper_ex_individuals, NULL, 0); tsk_size_t sample_set_sizes[3]; - tsk_size_t num_sample_sets; tsk_id_t sample_sets[ts.num_samples * 3]; + tsk_size_t num_sites = ts.tables->sites.num_rows; + tsk_id_t *row_sites = tsk_malloc(num_sites * sizeof(*row_sites)); + tsk_id_t *col_sites = tsk_malloc(num_sites * sizeof(*col_sites)); // First sample set contains all of the samples sample_set_sizes[0] = ts.num_samples; @@ -2059,14 +2065,18 @@ test_paper_ex_two_site(void) for (s = 0; s < ts.num_samples; s++) { sample_sets[s] = (tsk_id_t) s; } + 
for (s = 0; s < num_sites; s++) { + row_sites[s] = (tsk_id_t) s; + col_sites[s] = (tsk_id_t) s; + } - ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + result_size = num_sites * num_sites; + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 6); assert_arrays_almost_equal(result_size * num_sample_sets, result, truth_one_set); - tsk_safe_free(result); // Second sample set contains all of the samples sample_set_sizes[1] = ts.num_samples; @@ -2075,13 +2085,12 @@ test_paper_ex_two_site(void) sample_sets[s] = (tsk_id_t) s - (tsk_id_t) ts.num_samples; } - ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 6); assert_arrays_almost_equal(result_size * num_sample_sets, result, truth_two_sets); - tsk_safe_free(result); // Third sample set contains the first two samples sample_set_sizes[2] = 2; @@ -2090,15 +2099,16 @@ test_paper_ex_two_site(void) sample_sets[s] = (tsk_id_t) s - (tsk_id_t) ts.num_samples * 2; } - ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 6); assert_arrays_almost_equal(result_size * num_sample_sets, 
result, truth_three_sets); - tsk_safe_free(result); tsk_treeseq_free(&ts); + tsk_safe_free(row_sites); + tsk_safe_free(col_sites); } static void @@ -2145,83 +2155,86 @@ test_two_site_correlated_multiallelic(void) int ret; tsk_treeseq_t ts; - double *result; tsk_size_t s, result_size; - double truth_D[3] - = { 0.043209876543209874, -0.018518518518518517, 0.05555555555555555 }; - double truth_D2[3] - = { 0.023844603634269844, 0.02384460363426984, 0.02384460363426984 }; - double truth_r2[3] = { 1, 1, 1 }; - double truth_D_prime[3] - = { 0.7777777777777777, 0.4444444444444444, 0.6666666666666666 }; - double truth_r[3] - = { 0.18377223398316206, -0.12212786219416509, 0.2609542781331212 }; - double truth_Dz[3] - = { 0.0033870175616860566, 0.003387017561686057, 0.003387017561686057 }; - double truth_pi2[3] - = { 0.04579247743399549, 0.04579247743399549, 0.0457924774339955 }; + double truth_D[4] = { 0.043209876543209874, -0.018518518518518517, + -0.018518518518518517, 0.05555555555555555 }; + double truth_D2[4] = { 0.023844603634269844, 0.02384460363426984, + 0.02384460363426984, 0.02384460363426984 }; + double truth_r2[4] = { 1, 1, 1, 1 }; + double truth_D_prime[4] = { 0.7777777777777777, 0.4444444444444444, + 0.4444444444444444, 0.6666666666666666 }; + double truth_r[4] = { 0.18377223398316206, -0.12212786219416509, + -0.12212786219416509, 0.2609542781331212 }; + double truth_Dz[4] = { 0.0033870175616860566, 0.003387017561686057, + 0.003387017561686057, 0.003387017561686057 }; + double truth_pi2[4] = { 0.04579247743399549, 0.04579247743399549, + 0.04579247743399549, 0.0457924774339955 }; tsk_treeseq_from_text(&ts, 20, nodes, edges, NULL, sites, mutations, NULL, NULL, 0); - tsk_size_t sample_set_sizes[1] = { ts.num_samples }; tsk_size_t num_sample_sets = 1; + tsk_size_t sample_set_sizes[1] = { ts.num_samples }; tsk_id_t sample_sets[ts.num_samples]; + tsk_size_t num_sites = ts.tables->sites.num_rows; + tsk_id_t *row_sites = tsk_malloc(num_sites * sizeof(*row_sites)); + 
tsk_id_t *col_sites = tsk_malloc(num_sites * sizeof(*col_sites)); + result_size = num_sites * num_sites; + double result[result_size]; for (s = 0; s < ts.num_samples; s++) { sample_sets[s] = (tsk_id_t) s; } + for (s = 0; s < num_sites; s++) { + row_sites[s] = (tsk_id_t) s; + col_sites[s] = (tsk_id_t) s; + } - ret = tsk_treeseq_D(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_D(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 3); assert_arrays_almost_equal(result_size, result, truth_D); - tsk_safe_free(result); - ret = tsk_treeseq_D2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_D2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 3); assert_arrays_almost_equal(result_size, result, truth_D2); - tsk_safe_free(result); - ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 3); assert_arrays_almost_equal(result_size, result, truth_r2); - tsk_safe_free(result); - ret = tsk_treeseq_D_prime(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, - NULL, 0, NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_D_prime(&ts, 
num_sample_sets, sample_set_sizes, sample_sets, + num_sites, row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 3); assert_arrays_almost_equal(result_size, result, truth_D_prime); - tsk_safe_free(result); - ret = tsk_treeseq_r(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_r(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 3); assert_arrays_almost_equal(result_size, result, truth_r); - tsk_safe_free(result); - ret = tsk_treeseq_Dz(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_Dz(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 3); assert_arrays_almost_equal(result_size, result, truth_Dz); - tsk_safe_free(result); - ret = tsk_treeseq_pi2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, - 0, NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_pi2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 3); assert_arrays_almost_equal(result_size, result, truth_pi2); - tsk_safe_free(result); tsk_treeseq_free(&ts); + tsk_safe_free(row_sites); + tsk_safe_free(col_sites); } static void @@ -2278,78 +2291,81 @@ test_two_site_uncorrelated_multiallelic(void) tsk_treeseq_t ts; int ret; - double *result; - tsk_size_t result_size; - - double truth_D[3] = { 0.05555555555555555, 0.0, 
0.05555555555555555 }; - double truth_D2[3] = { 0.024691358024691357, 0.0, 0.024691358024691357 }; - double truth_r2[3] = { 1, 0, 1 }; - double truth_D_prime[3] = { 0.6666666666666665, 0.0, 0.6666666666666665 }; - double truth_r[3] = { 0.24999999999999997, 0.0, 0.24999999999999997 }; - double truth_Dz[3] = { 0.0, 0.0, 0.0 }; - double truth_pi2[3] - = { 0.04938271604938272, 0.04938271604938272, 0.04938271604938272 }; + + double truth_D[4] = { 0.05555555555555555, 0.0, 0.0, 0.05555555555555555 }; + double truth_D2[4] = { 0.024691358024691357, 0.0, 0.0, 0.024691358024691357 }; + double truth_r2[4] = { 1, 0, 0, 1 }; + double truth_D_prime[4] = { 0.6666666666666665, 0.0, 0.0, 0.6666666666666665 }; + double truth_r[4] = { 0.24999999999999997, 0.0, 0.0, 0.24999999999999997 }; + double truth_Dz[4] = { 0.0, 0.0, 0.0, 0.0 }; + double truth_pi2[4] = { 0.04938271604938272, 0.04938271604938272, + 0.04938271604938272, 0.04938271604938272 }; tsk_treeseq_from_text(&ts, 20, nodes, edges, NULL, sites, mutations, NULL, NULL, 0); - tsk_size_t sample_set_sizes[1] = { ts.num_samples }; + tsk_size_t s; tsk_size_t num_sample_sets = 1; + tsk_size_t num_sites = ts.tables->sites.num_rows; + tsk_id_t *row_sites = tsk_malloc(num_sites * sizeof(*row_sites)); + tsk_id_t *col_sites = tsk_malloc(num_sites * sizeof(*col_sites)); + tsk_size_t sample_set_sizes[1] = { ts.num_samples }; tsk_id_t sample_sets[ts.num_samples]; + tsk_size_t result_size = num_sites * num_sites; + double result[result_size]; - for (tsk_size_t s = 0; s < ts.num_samples; s++) { + for (s = 0; s < ts.num_samples; s++) { sample_sets[s] = (tsk_id_t) s; } + for (s = 0; s < num_sites; s++) { + row_sites[s] = (tsk_id_t) s; + col_sites[s] = (tsk_id_t) s; + } - ret = tsk_treeseq_D(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_D(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, 
+ row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 3); assert_arrays_almost_equal(result_size, result, truth_D); - tsk_safe_free(result); - ret = tsk_treeseq_D2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_D2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 3); assert_arrays_almost_equal(result_size, result, truth_D2); - tsk_safe_free(result); - ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 3); assert_arrays_almost_equal(result_size, result, truth_r2); - tsk_safe_free(result); - ret = tsk_treeseq_D_prime(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, - NULL, 0, NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_D_prime(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_sites, row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 3); assert_arrays_almost_equal(result_size, result, truth_D_prime); - tsk_safe_free(result); - ret = tsk_treeseq_r(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_r(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); 
CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 3); assert_arrays_almost_equal(result_size, result, truth_r); - tsk_safe_free(result); - ret = tsk_treeseq_Dz(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_Dz(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 3); assert_arrays_almost_equal(result_size, result, truth_Dz); - tsk_safe_free(result); - ret = tsk_treeseq_pi2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, - 0, NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_pi2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 3); assert_arrays_almost_equal(result_size, result, truth_pi2); - tsk_safe_free(result); tsk_treeseq_free(&ts); + tsk_safe_free(row_sites); + tsk_safe_free(col_sites); } static void @@ -2386,44 +2402,122 @@ test_two_site_backmutation(void) "1 58 A 4\n"; int ret; - double *result; - tsk_size_t result_size; tsk_treeseq_t ts; tsk_treeseq_from_text(&ts, 10, nodes, edges, NULL, sites, mutations, NULL, NULL, 0); - tsk_size_t sample_set_sizes[1] = { ts.num_samples }; tsk_size_t num_sample_sets = 1; + tsk_size_t num_sites = ts.tables->sites.num_rows; + tsk_id_t *row_sites = tsk_malloc(num_sites * sizeof(*row_sites)); + tsk_id_t *col_sites = tsk_malloc(num_sites * sizeof(*col_sites)); + tsk_size_t sample_set_sizes[1] = { ts.num_samples }; tsk_id_t sample_sets[ts.num_samples]; + tsk_size_t result_size = num_sites * num_sites; + double result[result_size]; + tsk_size_t s; - for (tsk_size_t s = 0; s < ts.num_samples; s++) { + double truth_r2[4] = { 0.999999999999999, 
0.042923862278701, 0.042923862278701, 1. }; + + for (s = 0; s < ts.num_samples; s++) { sample_sets[s] = (tsk_id_t) s; } + for (s = 0; s < num_sites; s++) { + row_sites[s] = (tsk_id_t) s; + col_sites[s] = (tsk_id_t) s; + } - ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 3); - /* assert_arrays_almost_equal(result_size, result, truth_r2); */ - tsk_safe_free(result); + assert_arrays_almost_equal(result_size, result, truth_r2); tsk_treeseq_free(&ts); + tsk_safe_free(row_sites); + tsk_safe_free(col_sites); } static void -test_two_locus_stat_input_errors(void) +test_paper_ex_two_site_subset(void) { tsk_treeseq_t ts; - double *result; + double result[4]; + int ret; tsk_size_t s, result_size; + tsk_size_t sample_set_sizes[1]; + tsk_size_t num_sample_sets; + tsk_id_t row_sites[2] = { 0, 1 }; + tsk_id_t col_sites[2] = { 1, 2 }; + double result_truth_1[4] = { 0.1111111111111111, 0.1111111111111111, 1, 1 }; + double result_truth_2[1] = { 0.1111111111111111 }; + double result_truth_3[4] = { 0.1111111111111111, 1, 0.1111111111111111, 1 }; + + tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, + paper_ex_mutations, paper_ex_individuals, NULL, 0); + + tsk_id_t sample_sets[ts.num_samples]; + + sample_set_sizes[0] = ts.num_samples; + num_sample_sets = 1; + for (s = 0; s < ts.num_samples; s++) { + sample_sets[s] = (tsk_id_t) s; + } + + result_size = 2 * 2; + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 2, + row_sites, 2, col_sites, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + 
assert_arrays_almost_equal(result_size * num_sample_sets, result, result_truth_1); + + result_size = 1 * 1; + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + col_sites[0] = 2; + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 1, + row_sites, 1, col_sites, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size * num_sample_sets, result, result_truth_2); + + result_size = 2 * 2; + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + row_sites[0] = 1; + row_sites[1] = 2; + col_sites[0] = 0; + col_sites[1] = 1; + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 2, + row_sites, 2, col_sites, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size * num_sample_sets, result, result_truth_3); + + tsk_treeseq_free(&ts); +} + +static void +test_two_locus_stat_input_errors(void) +{ + tsk_treeseq_t ts; int ret; tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, single_tree_ex_sites, single_tree_ex_mutations, NULL, NULL, 0); - tsk_size_t sample_set_sizes[1]; - tsk_size_t num_sample_sets; + tsk_size_t num_sites = ts.tables->sites.num_rows; + tsk_id_t *row_sites = tsk_malloc(num_sites * sizeof(*row_sites)); + tsk_id_t *col_sites = tsk_malloc(num_sites * sizeof(*col_sites)); + tsk_size_t sample_set_sizes[1] = { ts.num_samples }; + tsk_size_t num_sample_sets = 1; tsk_id_t sample_sets[ts.num_samples]; + tsk_size_t result_size = num_sites * num_sites; + double result[result_size]; + tsk_size_t s; + + for (s = 0; s < ts.num_samples; s++) { + sample_sets[s] = (tsk_id_t) s; + } + for (s = 0; s < num_sites; s++) { + row_sites[s] = (tsk_id_t) s; + col_sites[s] = (tsk_id_t) s; + } sample_set_sizes[0] = ts.num_samples; num_sample_sets = 1; @@ -2432,36 +2526,70 @@ test_two_locus_stat_input_errors(void) } sample_sets[1] = 0; - ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 
0, - NULL, 0, &result_size, &result); + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_DUPLICATE_SAMPLE); sample_sets[1] = 1; - ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, TSK_STAT_SITE | TSK_STAT_BRANCH, &result_size, &result); + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, TSK_STAT_SITE | TSK_STAT_BRANCH, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MULTIPLE_STAT_MODES); - ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, TSK_STAT_BRANCH, &result_size, &result); + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, TSK_STAT_BRANCH, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSUPPORTED_STAT_MODE); - ret = tsk_treeseq_r2(&ts, 0, sample_set_sizes, sample_sets, 0, NULL, 0, NULL, 0, - &result_size, &result); + ret = tsk_treeseq_r2(&ts, 0, sample_set_sizes, sample_sets, num_sites, row_sites, + num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_STATE_DIMS); sample_set_sizes[0] = 0; - ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_EMPTY_SAMPLE_SET); sample_set_sizes[0] = ts.num_samples; sample_sets[1] = 10; - ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); sample_sets[1] = 1; + row_sites[0] = 1000; + 
ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SITE_OUT_OF_BOUNDS); + row_sites[0] = 0; + + col_sites[num_sites - 1] = (tsk_id_t) num_sites; + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SITE_OUT_OF_BOUNDS); + col_sites[num_sites - 1] = (tsk_id_t) num_sites - 1; + + row_sites[0] = 1; + row_sites[1] = 0; + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSORTED_SITES); + row_sites[0] = 0; + row_sites[1] = 1; + + row_sites[0] = 1; + row_sites[1] = 1; + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSORTED_SITES); + row_sites[0] = 0; + row_sites[1] = 1; + + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, + NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_SITE_POSITION); + tsk_treeseq_free(&ts); + tsk_safe_free(row_sites); + tsk_safe_free(col_sites); } static void @@ -2744,6 +2872,7 @@ main(int argc, char **argv) { "test_two_site_uncorrelated_multiallelic", test_two_site_uncorrelated_multiallelic }, { "test_two_site_backmutation", test_two_site_backmutation }, + { "test_paper_ex_two_site_subset", test_paper_ex_two_site_subset }, { "test_two_locus_stat_input_errors", test_two_locus_stat_input_errors }, { "test_simplest_divergence_matrix", test_simplest_divergence_matrix }, diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 56b2661c12..61fbf686d2 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -2355,18 +2355,63 @@ compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state, return ret; } +static void 
+get_site_row_col_indices(tsk_size_t n_rows, const tsk_id_t *row_sites, tsk_size_t n_cols, + const tsk_id_t *col_sites, tsk_id_t *sites, tsk_size_t *n_sites, tsk_size_t *row_idx, + tsk_size_t *col_idx) +{ + tsk_size_t r = 0, c = 0, s = 0; + + // Iterate rows and columns until we've exhausted one of the lists + while ((r < n_rows) && (c < n_cols)) { + if (row_sites[r] < col_sites[c]) { + sites[s] = row_sites[r]; + row_idx[r] = s; + s++; + r++; + } else if (col_sites[c] < row_sites[r]) { + sites[s] = col_sites[c]; + col_idx[c] = s; + s++; + c++; + } else { // row == col + sites[s] = row_sites[r]; + col_idx[c] = s; + row_idx[r] = s; + s++; + r++; + c++; + } + } + + // If there are any items remaining in the other list, drain it + while (r < n_rows) { + sites[s] = row_sites[r]; + row_idx[r] = s; + s++; + r++; + } + while (c < n_cols) { + sites[s] = col_sites[c]; + col_idx[c] = s; + s++; + c++; + } + *n_sites = s; +} + static int -get_mutation_samples( - const tsk_treeseq_t *ts, tsk_size_t *num_alleles, tsk_bit_array_t *allele_samples) +get_mutation_samples(const tsk_treeseq_t *ts, const tsk_id_t *sites, tsk_size_t n_sites, + tsk_size_t *num_alleles, tsk_bit_array_t *allele_samples) { int ret = 0; const tsk_flags_t *restrict flags = ts->tables->nodes.flags; const tsk_size_t num_samples = tsk_treeseq_get_num_samples(ts); const tsk_size_t *restrict site_muts_len = ts->site_mutations_length; - const tsk_site_t *restrict site; + tsk_site_t site; tsk_tree_t tree; tsk_bit_array_t all_samples_bits, mut_samples, mut_samples_row, out_row; - tsk_size_t max_muts_len, mut_offset, num_nodes, s, m, n; + tsk_size_t max_muts_len, site_offset, num_nodes, site_idx, s, m, n; tsk_id_t node, *nodes = NULL; void *tmp_nodes; @@ -2374,11 +2419,12 @@ get_mutation_samples( tsk_memset(&all_samples_bits, 0, sizeof(all_samples_bits)); max_muts_len = 0; - for (s = 0; s < ts->tables->sites.num_rows; s++) { - if (site_muts_len[s] > max_muts_len) { - max_muts_len = site_muts_len[s]; + for (s = 0; s < 
n_sites; s++) { + if (site_muts_len[sites[s]] > max_muts_len) { + max_muts_len = site_muts_len[sites[s]]; } } + // Allocate a bit array of size max alleles for all sites ret = tsk_bit_array_init(&mut_samples, num_samples, max_muts_len); if (ret != 0) { goto out; @@ -2387,103 +2433,111 @@ get_mutation_samples( if (ret != 0) { goto out; } - + get_all_samples_bits(&all_samples_bits, num_samples); ret = tsk_tree_init(&tree, ts, TSK_NO_SAMPLE_COUNTS); if (ret != 0) { goto out; } - // A future improvement could get a union of all sample sets - // instead of all samples - get_all_samples_bits(&all_samples_bits, num_samples); - - // Traverse down each tree, recording all samples below each mutation. We perform one - // preorder traversal per mutation. - mut_offset = 0; - for (ret = tsk_tree_first(&tree); ret == TSK_TREE_OK; ret = tsk_tree_next(&tree)) { + // For each mutation within each site, perform one preorder traversal to gather + // the samples under each mutation's node. + site_offset = 0; + for (site_idx = 0; site_idx < n_sites; site_idx++) { + tsk_treeseq_get_site(ts, sites[site_idx], &site); + ret = tsk_tree_seek(&tree, site.position, 0); + if (ret != 0) { + goto out; + } tmp_nodes = tsk_realloc(nodes, tsk_tree_get_size_bound(&tree) * sizeof(*nodes)); if (tmp_nodes == NULL) { ret = TSK_ERR_NO_MEMORY; goto out; } nodes = tmp_nodes; - for (s = 0; s < tree.sites_length; s++) { - site = &tree.sites[s]; - tsk_bit_array_get_row(allele_samples, mut_offset, &out_row); - tsk_bit_array_add(&out_row, &all_samples_bits); - // Zero out results before the start of each iteration - tsk_memset(mut_samples.data, 0, - mut_samples.size * max_muts_len * sizeof(tsk_bit_array_value_t)); - for (m = 0; m < site->mutations_length; m++) { - tsk_bit_array_get_row(&mut_samples, m, &mut_samples_row); - node = site->mutations[m].node; - ret = tsk_tree_preorder_from(&tree, node, nodes, &num_nodes); - if (ret != 0) { - goto out; - } - for (n = 0; n < num_nodes; n++) { - node = nodes[n]; - if 
(flags[node] & TSK_NODE_IS_SAMPLE) { - tsk_bit_array_add_bit( - &mut_samples_row, (tsk_bit_array_value_t) node); - } + + tsk_bit_array_get_row(allele_samples, site_offset, &out_row); + tsk_bit_array_add(&out_row, &all_samples_bits); + + // Zero out results before the start of each iteration + tsk_memset(mut_samples.data, 0, + mut_samples.size * max_muts_len * sizeof(tsk_bit_array_value_t)); + for (m = 0; m < site.mutations_length; m++) { + tsk_bit_array_get_row(&mut_samples, m, &mut_samples_row); + node = site.mutations[m].node; + ret = tsk_tree_preorder_from(&tree, node, nodes, &num_nodes); + if (ret != 0) { + goto out; + } + for (n = 0; n < num_nodes; n++) { + node = nodes[n]; + if (flags[node] & TSK_NODE_IS_SAMPLE) { + tsk_bit_array_add_bit( + &mut_samples_row, (tsk_bit_array_value_t) node); } - mut_offset++; } - mut_offset++; // One more for the ancestral allele - get_allele_samples(site, &mut_samples, &out_row, &(num_alleles[site->id])); } + site_offset += site.mutations_length + 1; + get_allele_samples(&site, &mut_samples, &out_row, &(num_alleles[site_idx])); } - // if adding code below, check ret before continuing +// if adding code below, check ret before continuing out: tsk_safe_free(nodes); tsk_tree_free(&tree); tsk_bit_array_free(&mut_samples); tsk_bit_array_free(&all_samples_bits); - return ret; + return ret == TSK_TREE_OK ? 
0 : ret; } static int tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, const tsk_bit_array_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, - sample_count_stat_params_t *f_params, norm_func_t *norm_f, - const double *TSK_UNUSED(left_window), const double *TSK_UNUSED(right_window), - tsk_flags_t options, tsk_size_t *result_size, double **result) + sample_count_stat_params_t *f_params, norm_func_t *norm_f, tsk_size_t n_rows, + const tsk_id_t *row_sites, tsk_size_t n_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result) { + int ret = 0; - tsk_bit_array_t allele_samples; - tsk_bit_array_t site_a_state, site_b_state; - tsk_size_t inner, result_offset, inner_offset, a_offset, b_offset; - tsk_size_t site_a, site_b; + tsk_bit_array_t allele_samples, c_state, r_state; bool polarised = false; - const tsk_size_t num_sites = self->tables->sites.num_rows; + tsk_id_t *sites; + tsk_size_t r, c, s, n_alleles, n_sites, *row_idx, *col_idx; + double *result_row; const tsk_size_t num_samples = self->num_samples; - const tsk_size_t max_alleles = self->tables->mutations.num_rows + num_sites; - tsk_size_t *num_alleles = tsk_malloc(num_sites * sizeof(*num_alleles)); - const tsk_size_t *restrict site_muts_len = self->site_mutations_length; + tsk_size_t *num_alleles = NULL, *site_offsets = NULL; + tsk_size_t result_row_len = n_cols * result_dim; tsk_memset(&allele_samples, 0, sizeof(allele_samples)); - if (num_alleles == NULL) { + sites = tsk_malloc(self->tables->sites.num_rows * sizeof(*sites)); + row_idx = tsk_malloc(self->tables->sites.num_rows * sizeof(*row_idx)); + col_idx = tsk_malloc(self->tables->sites.num_rows * sizeof(*col_idx)); + if (sites == NULL || row_idx == NULL || col_idx == NULL) { ret = TSK_ERR_NO_MEMORY; goto out; } + get_site_row_col_indices( + n_rows, row_sites, n_cols, col_sites, sites, &n_sites, row_idx, col_idx); - ret = tsk_bit_array_init(&allele_samples, num_samples, max_alleles); - if (ret != 0) { + 
// We rely on n_sites to allocate these arrays, they're initialized to NULL for safe + // deallocation if the previous allocation fails + num_alleles = tsk_malloc(n_sites * sizeof(*num_alleles)); + site_offsets = tsk_malloc(n_sites * sizeof(*site_offsets)); + if (num_alleles == NULL || site_offsets == NULL) { + ret = TSK_ERR_NO_MEMORY; goto out; } - ret = get_mutation_samples(self, num_alleles, &allele_samples); + + n_alleles = 0; + for (s = 0; s < n_sites; s++) { + site_offsets[s] = n_alleles; + n_alleles += self->site_mutations_length[sites[s]] + 1; + } + ret = tsk_bit_array_init(&allele_samples, num_samples, n_alleles); if (ret != 0) { goto out; } - - // Number of pairs w/ replacement (sites) - *result_size = (num_sites * (1 + num_sites)) / 2U; - *result = tsk_calloc(*result_size * result_dim, sizeof(**result)); - - if (result == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = get_mutation_samples(self, sites, n_sites, num_alleles, &allele_samples); + if (ret != 0) { goto out; } @@ -2491,34 +2545,28 @@ tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, polarised = true; } - inner = 0; - a_offset = 0; - b_offset = 0; - inner_offset = 0; - result_offset = 0; - // TODO: implement windows! - for (site_a = 0; site_a < num_sites; site_a++) { - b_offset = inner_offset; - for (site_b = inner; site_b < num_sites; site_b++) { - tsk_bit_array_get_row(&allele_samples, a_offset, &site_a_state); - tsk_bit_array_get_row(&allele_samples, b_offset, &site_b_state); - ret = compute_general_two_site_stat_result(&site_a_state, &site_b_state, - num_alleles[site_a], num_alleles[site_b], num_samples, state_dim, + // For each row/column pair, fill in the sample set in the result matrix. 
+ for (r = 0; r < n_rows; r++) { + result_row = GET_2D_ROW(result, result_row_len, r); + for (c = 0; c < n_cols; c++) { + tsk_bit_array_get_row(&allele_samples, site_offsets[row_idx[r]], &r_state); + tsk_bit_array_get_row(&allele_samples, site_offsets[col_idx[c]], &c_state); + ret = compute_general_two_site_stat_result(&r_state, &c_state, + num_alleles[row_idx[r]], num_alleles[col_idx[c]], num_samples, state_dim, sample_sets, result_dim, f, f_params, norm_f, polarised, - &((*result)[result_offset])); + &(result_row[c * result_dim])); if (ret != 0) { goto out; } - result_offset += result_dim; - b_offset += site_muts_len[site_b] + 1; } - a_offset += site_muts_len[site_a] + 1; - inner_offset += site_muts_len[site_a] + 1; - inner++; } out: + tsk_safe_free(sites); + tsk_safe_free(row_idx); + tsk_safe_free(col_idx); tsk_safe_free(num_alleles); + tsk_safe_free(site_offsets); tsk_bit_array_free(&allele_samples); return ret; } @@ -2558,14 +2606,43 @@ sample_sets_to_bit_array(const tsk_treeseq_t *self, const tsk_size_t *sample_set return ret; } +static int +check_sites(const tsk_id_t *sites, tsk_size_t num_sites, tsk_size_t num_site_rows) +{ + int ret = 0; + tsk_size_t i; + + if (sites == NULL || num_sites == 0) { + ret = TSK_ERR_BAD_SITE_POSITION; // TODO: error should be no sites? 
+ goto out; + } + + for (i = 0; i < num_sites - 1; i++) { + if (sites[i] < 0 || sites[i] >= (tsk_id_t) num_site_rows) { + ret = TSK_ERR_SITE_OUT_OF_BOUNDS; + goto out; + } + if (sites[i] >= sites[i + 1]) { + // TODO: this checks no repeats, but error is ambiguous + ret = TSK_ERR_UNSORTED_SITES; + goto out; + } + } + // check the last value + if (sites[i] < 0 || sites[i] >= (tsk_id_t) num_site_rows) { + ret = TSK_ERR_SITE_OUT_OF_BOUNDS; + goto out; + } +out: + return ret; +} + static int tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t result_dim, const tsk_id_t *set_indexes, general_stat_func_t *f, - norm_func_t *norm_f, tsk_size_t TSK_UNUSED(num_left_windows), - const double *left_windows, tsk_size_t TSK_UNUSED(num_right_windows), - const double *right_windows, tsk_flags_t options, tsk_size_t *result_size, - double **result) + norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites, + tsk_size_t out_cols, const tsk_id_t *col_sites, tsk_flags_t options, double *result) { // TODO: generalize this function if we ever decide to do weighted two_locus stats. // We only implement count stats and therefore we don't handle weights. 
@@ -2601,8 +2678,6 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl // goto out; // } - tsk_bug_assert(left_windows == NULL && right_windows == NULL); - ret = tsk_treeseq_check_sample_sets( self, num_sample_sets, sample_set_sizes, sample_sets); if (ret != 0) { @@ -2615,9 +2690,17 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl } if (stat_site) { + ret = check_sites(row_sites, out_rows, self->tables->sites.num_rows); + if (ret != 0) { + goto out; + } + ret = check_sites(col_sites, out_cols, self->tables->sites.num_rows); + if (ret != 0) { + goto out; + } ret = tsk_treeseq_two_site_count_stat(self, state_dim, &sample_sets_bits, - result_dim, f, &f_params, norm_f, left_windows, right_windows, options, - result_size, result); + result_dim, f, &f_params, norm_f, out_rows, row_sites, out_cols, col_sites, + options, result); } else { ret = TSK_ERR_UNSUPPORTED_STAT_MODE; } @@ -3451,16 +3534,14 @@ D_summary_func(tsk_size_t state_dim, const double *state, int tsk_treeseq_D(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_left_windows, const double *left_windows, - tsk_size_t num_right_windows, const double *right_windows, tsk_flags_t options, - tsk_size_t *result_size, double **result) + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result) { options |= TSK_STAT_POLARISED; // TODO: allow user to pick? 
return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, D_summary_func, norm_total_weighted, - num_left_windows, left_windows, num_right_windows, right_windows, options, - result_size, result); + num_rows, row_sites, num_cols, col_sites, options, result); } static int @@ -3490,15 +3571,13 @@ D2_summary_func(tsk_size_t state_dim, const double *state, int tsk_treeseq_D2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_left_windows, const double *left_windows, - tsk_size_t num_right_windows, const double *right_windows, tsk_flags_t options, - tsk_size_t *result_size, double **result) + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result) { return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, D2_summary_func, norm_total_weighted, - num_left_windows, left_windows, num_right_windows, right_windows, options, - result_size, result); + num_rows, row_sites, num_cols, col_sites, options, result); } static int @@ -3534,15 +3613,13 @@ r2_summary_func(tsk_size_t state_dim, const double *state, int tsk_treeseq_r2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_left_windows, const double *left_windows, - tsk_size_t num_right_windows, const double *right_windows, tsk_flags_t options, - tsk_size_t *result_size, double **result) + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result) { return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, - sample_sets, 
num_sample_sets, NULL, r2_summary_func, norm_hap_weighted, - num_left_windows, left_windows, num_right_windows, right_windows, options, - result_size, result); + sample_sets, num_sample_sets, NULL, r2_summary_func, norm_hap_weighted, num_rows, + row_sites, num_cols, col_sites, options, result); } static int @@ -3576,16 +3653,14 @@ D_prime_summary_func(tsk_size_t state_dim, const double *state, int tsk_treeseq_D_prime(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_left_windows, const double *left_windows, - tsk_size_t num_right_windows, const double *right_windows, tsk_flags_t options, - tsk_size_t *result_size, double **result) + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result) { options |= TSK_STAT_POLARISED; // TODO: allow user to pick? return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, D_prime_summary_func, norm_hap_weighted, - num_left_windows, left_windows, num_right_windows, right_windows, options, - result_size, result); + num_rows, row_sites, num_cols, col_sites, options, result); } static int @@ -3621,16 +3696,14 @@ r_summary_func(tsk_size_t state_dim, const double *state, int tsk_treeseq_r(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_left_windows, const double *left_windows, - tsk_size_t num_right_windows, const double *right_windows, tsk_flags_t options, - tsk_size_t *result_size, double **result) + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result) { options |= TSK_STAT_POLARISED; // TODO: allow user to 
pick? return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, r_summary_func, norm_total_weighted, - num_left_windows, left_windows, num_right_windows, right_windows, options, - result_size, result); + num_rows, row_sites, num_cols, col_sites, options, result); } static int @@ -3661,15 +3734,13 @@ Dz_summary_func(tsk_size_t state_dim, const double *state, int tsk_treeseq_Dz(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_left_windows, const double *left_windows, - tsk_size_t num_right_windows, const double *right_windows, tsk_flags_t options, - tsk_size_t *result_size, double **result) + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result) { return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, Dz_summary_func, norm_total_weighted, - num_left_windows, left_windows, num_right_windows, right_windows, options, - result_size, result); + num_rows, row_sites, num_cols, col_sites, options, result); } static int @@ -3697,15 +3768,13 @@ pi2_summary_func(tsk_size_t state_dim, const double *state, int tsk_treeseq_pi2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_left_windows, const double *left_windows, - tsk_size_t num_right_windows, const double *right_windows, tsk_flags_t options, - tsk_size_t *result_size, double **result) + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result) { return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, 
num_sample_sets, NULL, pi2_summary_func, norm_total_weighted, - num_left_windows, left_windows, num_right_windows, right_windows, options, - result_size, result); + num_rows, row_sites, num_cols, col_sites, options, result); } /*********************************** diff --git a/c/tskit/trees.h b/c/tskit/trees.h index dbc870ad2f..2faa3c95c6 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1046,41 +1046,39 @@ int tsk_treeseq_genetic_relatedness(const tsk_treeseq_t *self, const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); +typedef int two_locus_count_stat_method(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t num_rows, const tsk_id_t *row_sites, + tsk_size_t num_cols, const tsk_id_t *col_sites, tsk_flags_t options, double *result); + int tsk_treeseq_D(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_left_windows, const double *left_windows, - tsk_size_t num_right_windows, const double *right_windows, tsk_flags_t options, - tsk_size_t *result_size, double **result); + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result); int tsk_treeseq_D2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_left_windows, const double *left_windows, - tsk_size_t num_right_windows, const double *right_windows, tsk_flags_t options, - tsk_size_t *result_size, double **result); + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result); int tsk_treeseq_r2(const tsk_treeseq_t 
*self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_left_windows, const double *left_windows, - tsk_size_t num_right_windows, const double *right_windows, tsk_flags_t options, - tsk_size_t *result_size, double **result); + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result); int tsk_treeseq_D_prime(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_left_windows, const double *left_windows, - tsk_size_t num_right_windows, const double *right_windows, tsk_flags_t options, - tsk_size_t *result_size, double **result); + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result); int tsk_treeseq_r(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_left_windows, const double *left_windows, - tsk_size_t num_right_windows, const double *right_windows, tsk_flags_t options, - tsk_size_t *result_size, double **result); + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result); int tsk_treeseq_Dz(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_left_windows, const double *left_windows, - tsk_size_t num_right_windows, const double *right_windows, tsk_flags_t options, - tsk_size_t *result_size, double **result); + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t 
*row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result); int tsk_treeseq_pi2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_left_windows, const double *left_windows, - tsk_size_t num_right_windows, const double *right_windows, tsk_flags_t options, - tsk_size_t *result_size, double **result); + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result); /* Three way sample set stats */ int tsk_treeseq_Y3(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py new file mode 100644 index 0000000000..40ccd8f480 --- /dev/null +++ b/python/tests/test_ld_matrix.py @@ -0,0 +1,833 @@ +import io +from itertools import combinations_with_replacement +from itertools import permutations +from typing import Any +from typing import Callable +from typing import Dict +from typing import Generator +from typing import List +from typing import Tuple + +import numpy as np +import pytest + +import tskit + + +class BitSet: + """BitSet object, which stores values in arrays of unsigned integers. + The columns represent all possible values a bit can take, and the rows + represent each item that can be stored in the array. + + :param num_bits: The number of values that a single row can contain. + :param length: The number of rows.
+ """ + + DTYPE = np.uint32 # Data type to be stored in the bitset + CHUNK_SIZE = DTYPE(32) # Size of integer field to store the data in + + def __init__(self: "BitSet", num_bits: int, length: int) -> None: + self.row_len = num_bits // self.CHUNK_SIZE + self.row_len += 1 if num_bits % self.CHUNK_SIZE else 0 + self.row_len = int(self.row_len) + self.data = np.zeros(self.row_len * length, dtype=self.DTYPE) + + def intersect( + self: "BitSet", self_row: int, other: "BitSet", other_row: int, out: "BitSet" + ) -> None: + """Intersect a row from the current array instance with a row from + another BitSet and store it in an output bit array of length 1. + + NB: we don't specify the row in the output array, it is expected + to be length 1. + + :param self_row: Row from the current array instance to be intersected. + :param other: Other BitSet to intersect with. + :param other_row: Row from the other BitSet instance. + :param out: BitArray to store the result. + """ + self_offset = self_row * self.row_len + other_offset = other_row * self.row_len + + for i in range(self.row_len): + out.data[i] = self.data[i + self_offset] & other.data[i + other_offset] + + def difference( + self: "BitSet", self_row: int, other: "BitSet", other_row: int + ) -> None: + """Take the difference between the current array instance and another + array instance. Store the result in the specified row of the current + instance. + + :param self_row: Row from the current array from which to subtract. + :param other: Other BitSet to subtract from the current instance. + :param other_row: Row from the other BitSet instance. + """ + self_offset = self_row * self.row_len + other_offset = other_row * self.row_len + + for i in range(self.row_len): + self.data[i + self_offset] &= ~(other.data[i + other_offset]) + + def union(self: "BitSet", self_row: int, other: "BitSet", other_row: int) -> None: + """Take the union between the current array instance and another + array instance. 
Store the result in the specified row of the current + instance. + + :param self_row: Row from the current array with which to union. + :param other: Other BitSet to union with the current instance. + :param other_row: Row from the other BitSet instance. + """ + self_offset = self_row * self.row_len + other_offset = other_row * self.row_len + + for i in range(self.row_len): + self.data[i + self_offset] |= other.data[i + other_offset] + + def add(self: "BitSet", row: int, bit: int) -> None: + """Add a single bit to the row of a bit array + + :param row: Row to be modified. + :param bit: Bit to be added. + """ + offset = row * self.row_len + i = bit // self.CHUNK_SIZE + self.data[i + offset] |= self.DTYPE(1) << (bit - (self.CHUNK_SIZE * i)) + + def get_items(self: "BitSet", row: int) -> Generator[int, None, None]: + """Get the items stored in the row of a bitset + + :param row: Row from the array to list from. + :returns: A generator of integers stored in the array. + """ + offset = row * self.row_len + for i in range(self.row_len): + for item in range(self.CHUNK_SIZE): + if self.data[i + offset] & (self.DTYPE(1) << item): + yield item + (i * self.CHUNK_SIZE) + + def contains(self: "BitSet", row: int, bit: int) -> bool: + """Test if a bit is contained within a bit array row + + :param row: Row to test. + :param bit: Bit to check. + :returns: True if the bit is set in the row, else false. + """ + i = bit // self.CHUNK_SIZE + offset = row * self.row_len + return bool( + self.data[i + offset] & (self.DTYPE(1) << (bit - (self.CHUNK_SIZE * i))) + ) + + def count(self: "BitSet", row: int) -> int: + """Count all of the set bits in a specified row. Uses a SWAR + algorithm to count in parallel with a constant number (12) of operations. + + NB: we have to cast all values to our unsigned dtype to avoid type promotion + + Details here: + # https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel + + :param row: Row to count. 
+ :returns: Count of all of the set bits. + """ + count = 0 + offset = row * self.row_len + D = self.DTYPE + + for i in range(offset, offset + self.row_len): + v = self.data[i] + v = v - ((v >> D(1)) & D(0x55555555)) + v = (v & D(0x33333333)) + ((v >> D(2)) & D(0x33333333)) + # this operation relies on integer overflow + with np.errstate(over="ignore"): + count += ((v + (v >> D(4)) & D(0xF0F0F0F)) * D(0x1010101)) >> D(24) + + return count + + def count_naive(self: "BitSet", row: int) -> int: + """Naive counting algorithm implementing the same functionality as the count + method. Useful for testing correctness, uses the same number of operations + as set bits. + + Details here: + # https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetNaive + + :param row: Row to count. + :returns: Count of all of the set bits. + """ + count = 0 + offset = row * self.row_len + + for i in range(offset, offset + self.row_len): + v = self.data[i] + while v: + v &= v - self.DTYPE(1) + count += self.DTYPE(1) + return count + + +def norm_hap_weighted( + state_dim: int, + hap_weights: np.ndarray, + n_a: int, + n_b: int, + result: np.ndarray, + params: Dict[str, Any], +) -> None: + """Create a vector of normalizing coefficients, length of the number of + sample sets. In this normalization strategy, we weight each allele's + statistic by the proportion of the haplotype present. + + :param state_dim: Number of sample sets. + :param hap_weights: Proportion of each two-locus haplotype. + :param n_a: Number of alleles at the A locus. + :param n_b: Number of alleles at the B locus. + :param result: Result vector to store the normalizing coefficients in. + :param params: Params of summary function. 
+ """ + del n_a, n_b # handle unused params + sample_set_sizes = params["sample_set_sizes"] + for k in range(state_dim): + n = sample_set_sizes[k] + result[k] = hap_weights[0, k] / n + + +def norm_total_weighted( + state_dim: int, + hap_weights: np.ndarray, + n_a: int, + n_b: int, + result: np.ndarray, + params: Dict[str, Any], +) -> None: + """Create a vector of normalizing coefficients, length of the number of + sample sets. In this normalization strategy, we weight each allele's + statistic by the product of the allele frequencies + + :param state_dim: Number of sample sets. + :param hap_weights: Proportion of each two-locus haplotype. + :param n_a: Number of alleles at the A locus. + :param n_b: Number of alleles at the B locus. + :param result: Result vector to store the normalizing coefficients in. + :param params: Params of summary function. + """ + del hap_weights, params # handle unused params + for k in range(state_dim): + result[k] = 1 / (n_a * n_b) + + +def check_sites(sites, max_sites): + """Validate the specified site ids. + + We require that sites are: + + 1) Within the boundaries of available sites in the tree sequence + 2) Sorted + 3) Non-repeating + + Raises an exception if any error is found. + + :param sites: 1d array of sites to validate. + :param max_sites: Number of sites in the tree sequence, the upper + bound value for site ids. 
+ """ + if sites is None or len(sites) == 0: + raise ValueError("No sites provided") + i = 0 + for i in range(len(sites) - 1): + if sites[i] < 0 or sites[i] >= max_sites: + raise ValueError(f"Site out of bounds: {sites[i]}") + if sites[i] >= sites[i + 1]: + raise ValueError(f"Sites not sorted: {sites[i], sites[i + 1]}") + if sites[-1] < 0 or sites[-1] >= max_sites: + raise ValueError(f"Site out of bounds: {sites[i + 1]}") + + +def get_site_row_col_indices( + row_sites: List[int], col_sites: List[int] +) -> Tuple[List[int], List[int], List[int]]: + """Co-iterate over the row and column sites, keeping a sorted union of + site values and an index into the unique list of sites for both the row + and column sites. This function produces a list of sites of interest and + row and column indexes into this list of sites. + + NB: This routine requires that the site lists are sorted and deduplicated. + + :param row_sites: List of sites that will be represented in the output + matrix rows. + :param col_sites: List of sites that will be represented in the output + matrix columns. + :returns: Tuple of lists of sites, row, and column indices. + """ + r = 0 + c = 0 + s = 0 + sites = [] + col_idx = [] + row_idx = [] + + while r < len(row_sites) and c < len(col_sites): + if row_sites[r] < col_sites[c]: + sites.append(row_sites[r]) + row_idx.append(s) + s += 1 + r += 1 + elif row_sites[r] > col_sites[c]: + sites.append(col_sites[c]) + col_idx.append(s) + s += 1 + c += 1 + else: + sites.append(row_sites[r]) + row_idx.append(s) + col_idx.append(s) + s += 1 + r += 1 + c += 1 + while r < len(row_sites): + sites.append(row_sites[r]) + row_idx.append(s) + s += 1 + r += 1 + while c < len(col_sites): + sites.append(col_sites[c]) + col_idx.append(s) + s += 1 + c += 1 + + return sites, row_idx, col_idx + + +def get_all_samples_bits(num_samples: int) -> BitSet: + """Get the bits for all samples in the tree sequence. 
This is achieved + by creating a length 1 bitset and adding every sample's bit to it. + + :param num_samples: Number of samples contained in the tree sequence. + :returns: Length 1 BitSet containing all samples in the tree sequence. + """ + all_samples = BitSet(num_samples, 1) + for i in range(num_samples): + all_samples.add(0, i) + return all_samples + + +def get_allele_samples( + site: tskit.Site, site_offset: int, mut_samples: BitSet, allele_samples: BitSet +) -> int: + """Given a BitSet that has been arranged so that we have every sample under + a given mutation's node, create the final output where we know which samples + should belong under each mutation, considering the mutation's parentage, + back mutations, and ancestral state. + + To this end, we iterate over each mutation and store the samples under the + focal mutation in the output BitSet (allele_samples). Then, we check the + parent of the focal mutation (either a mutation or the ancestral allele), + and we subtract the samples in the focal mutation from the parent allele's + samples. + + :param site: Focal site for which to adjust mutation data. + :param site_offset: Offset into allele_samples for our focal site. + :param mut_samples: BitSet containing the samples under each mutation in the + focal site. + :param allele_samples: Output BitSet, initially passed in with all of the + tree sequence samples set in the ancestral allele + state. + :returns: number of alleles actually encountered (adjusting for back-mutation). 
+ """ + alleles = [] + num_alleles = 1 + alleles.append(site.ancestral_state) + + for m, mut in enumerate(site.mutations): + try: + allele = alleles.index(mut.derived_state) + except ValueError: + allele = len(alleles) + alleles.append(mut.derived_state) + num_alleles += 1 + allele_samples.union(allele + site_offset, mut_samples, m) + # now to find the parent allele from which we must subtract + alt_allele_state = site.ancestral_state + if mut.parent != tskit.NULL: + parent_mut = site.mutations[mut.parent - site.mutations[0].id] + alt_allele_state = parent_mut.derived_state + alt_allele = alleles.index(alt_allele_state) + # subtract focal allele's samples from the alt allele + allele_samples.difference( + alt_allele + site_offset, allele_samples, allele + site_offset + ) + + return num_alleles + + +def get_mutation_samples( + ts: tskit.TreeSequence, sites: List[int] +) -> Tuple[np.ndarray, np.ndarray, BitSet]: + """For a given set of sites, generate a BitSet of all samples possessing + each allelic state for each site. This includes the ancestral state, along + with any mutations contained in the site. + + We achieve this goal by starting at the tree containing the first site in + our list, then we walk along each tree until we've encountered the last + tree containing the last site in our list. Along the way, we perform a + preorder traversal from the node of each mutation in a given site, storing + the samples under that particular node. After we've stored all of the samples + for each allele at a site, we adjust each allele's samples by removing + samples that have a different allele at a child mutation down the tree (see + get_allele_samples for more details). + + We also gather some ancillary data while we iterate over the sites: the + number of alleles for each site, and the offset of each site. The number of + alleles at each site includes the count of mutations + the ancestral allele.
+ The offset for each site indicates how many array entries we must skip (i.e. + how many alleles exist before a specific site's entry) in order to address + the data for a given site. + + :param ts: Tree sequence to gather data from. + :param sites: Subset of sites to consider when gathering data. + :returns: Tuple of the number of alleles per site, site offsets, and the + BitSet of all samples in each allelic state. + """ + num_alleles = np.zeros(len(sites), dtype=np.uint64) + site_offsets = np.zeros(len(sites), dtype=np.uint64) + all_samples = get_all_samples_bits(ts.num_samples) + allele_samples = BitSet( + ts.num_samples, sum(len(ts.site(i).mutations) + 1 for i in sites) + ) + + site_offset = 0 + site_idx = 0 + for site_idx, site_id in enumerate(sites): + site = ts.site(site_id) + tree = ts.at(site.position) + # initialize the ancestral allele with all samples + allele_samples.union(site_offset, all_samples, 0) + # store samples for each mutation in mut_samples + mut_samples = BitSet(ts.num_samples, len(site.mutations)) + for m, mut in enumerate(site.mutations): + for node in tree.preorder(mut.node): + if ts.node(node).is_sample(): + mut_samples.add(m, node) + # account for mutation parentage, subtract samples from mutation parents + num_alleles[site_idx] = get_allele_samples( + site, site_offset, mut_samples, allele_samples + ) + # increment the offset for ancestral + mutation alleles + site_offsets[site_idx] = site_offset + site_offset += len(site.mutations) + 1 + + return num_alleles, site_offsets, allele_samples + + +def compute_general_two_site_stat_result( + row_site_offset: int, + col_site_offset: int, + num_row_alleles: int, + num_col_alleles: int, + num_samples: int, + allele_samples: BitSet, + state_dim: int, + sample_sets: BitSet, + func: Callable[[int, np.ndarray, np.ndarray, Dict[str, Any]], None], + norm_func: Callable[[int, np.ndarray, int, int, np.ndarray, Dict[str, Any]], None], + params: Dict[str, Any], + polarised: bool, + result: np.ndarray,
+) -> None: + """For a given pair of sites, compute the summary statistic for the allele + frequencies for each allelic state of the two pairs. + + :param row_site_offset: Offset of the row site's data in the allele_samples. + :param col_site_offset: Offset of the col site's data in the allele_samples. + :param num_row_alleles: Number of alleles in the row site. + :param num_col_alleles: Number of alleles in the col site. + :param num_samples: Number of samples in tree sequence. + :param allele_samples: BitSet containing the samples with each allelic state + for each site of interest. + :param state_dim: Number of sample sets. + :param sample_sets: BitSet of sample sets to be intersected with the samples + contained within each allele. + :param func: Summary function used to compute each two-locus statistic. + :param norm_func: Function used to generate the normalization coefficients + for each statistic. + :param params: Parameters to pass to the norm and summary function. + :param polarised: If true, skip the computation of the statistic for the + ancestral state. + :param result: Vector of the results matrix to populate. We will produce one + value per sample set, hence the vector of length state_dim.
+ """ + ss_A_samples = BitSet(num_samples, 1) + ss_B_samples = BitSet(num_samples, 1) + ss_AB_samples = BitSet(num_samples, 1) + AB_samples = BitSet(num_samples, 1) + weights = np.zeros((3, state_dim), np.float64) + norm = np.zeros(state_dim, np.float64) + result_tmp = np.zeros(state_dim, np.float64) + + polarised_val = 1 if polarised else 0 + + for mut_a in range(polarised_val, num_row_alleles): + a = int(mut_a + row_site_offset) + for mut_b in range(polarised_val, num_col_alleles): + b = int(mut_b + col_site_offset) + allele_samples.intersect(a, allele_samples, b, AB_samples) + for k in range(state_dim): + allele_samples.intersect(a, sample_sets, k, ss_A_samples) + allele_samples.intersect(b, sample_sets, k, ss_B_samples) + AB_samples.intersect(0, sample_sets, k, ss_AB_samples) + + w_AB = ss_AB_samples.count(0) + w_A = ss_A_samples.count(0) + w_B = ss_B_samples.count(0) + + weights[0, k] = w_AB + weights[1, k] = w_A - w_AB # w_Ab + weights[2, k] = w_B - w_AB # w_aB + + func(state_dim, weights, result_tmp, params) + + norm_func( + state_dim, + weights, + num_row_alleles - polarised_val, + num_col_alleles - polarised_val, + norm, + params, + ) + + for k in range(state_dim): + result[k] += result_tmp[k] * norm[k] + + +def two_site_count_stat( + ts: tskit.TreeSequence, + func: Callable[[int, np.ndarray, np.ndarray, Dict[str, Any]], None], + norm_func: Callable[[int, np.ndarray, int, int, np.ndarray, Dict[str, Any]], None], + num_sample_sets: int, + sample_set_sizes: np.ndarray, + sample_sets: BitSet, + row_sites: List[int], + col_sites: List[int], + polarised: bool, +) -> np.ndarray: + """Outer function that generates the high-level intermediates used in the + computation of our two-locus statistics. First, we compute the row and + column indices for our unique list of sites, then we get each sample for + each allele in our list of specified sites. 
+ + With those intermediates in hand, we iterate over the row and column indices + to compute comparisons between each of the specified lists of sites. We pass + a vector of results to the computation, which will compute a single result + for each sample set, inserting that into our result matrix. + + :param ts: Tree sequence to gather data from. + :param func: Function used to compute each two-locus statistic. + :param norm_func: Function used to generate the normalization coefficients + for each statistic. + :param num_sample_sets: Number of sample sets that we will consider. + :param sample_set_sizes: Number of samples in each sample set. + :param sample_sets: BitSet of samples to compute stats for. We will only + consider these samples in our computations, resulting + in stats that are computed on subsets of the samples + on the tree sequence. + :param row_sites: Sites contained in the rows of the output matrix. + :param col_sites: Sites contained in the columns of the output matrix. + :param polarised: If true, skip the computation of the statistic for the + ancestral state. + :returns: 3D array of results, dimensions (sample_sets, row_sites, col_sites). 
+ """ + state_dim = len(sample_set_sizes) + params = {"sample_set_sizes": sample_set_sizes} + result = np.zeros( + (num_sample_sets, len(row_sites), len(col_sites)), dtype=np.float64 + ) + + sites, row_idx, col_idx = get_site_row_col_indices(row_sites, col_sites) + num_alleles, site_offsets, allele_samples = get_mutation_samples(ts, sites) + + for row, row_site in enumerate(row_idx): + for col, col_site in enumerate(col_idx): + compute_general_two_site_stat_result( + site_offsets[row_site], + site_offsets[col_site], + num_alleles[row_site], + num_alleles[col_site], + ts.num_samples, + allele_samples, + state_dim, + sample_sets, + func, + norm_func, + params, + polarised, + result[:, row, col], + ) + + return result + + +def sample_sets_to_bit_array( + ts: tskit.TreeSequence, sample_sets: List[List[int]] +) -> Tuple[np.ndarray, BitSet]: + """Convert the list of sample ids to a bit array. This function takes + sample identifiers and maps them to their enumerated integer values, then + stores these values in a bit array. We produce a BitArray and a numpy + array of integers that specify how many samples there are in each sample set. + + NB: this function's type signature is of type integer, but I believe this + could be expanded to Any, currently untested so the integer + specification remains. + + :param ts: Tree sequence to gather data from. + :param sample_sets: List of sample identifiers to store in bit array. + :returns: Tuple containing numpy array of sample set sizes and the sample + set BitSet. 
+ """ + sample_sets_bits = BitSet(ts.num_samples, len(sample_sets)) + sample_index_map = -np.ones(ts.num_nodes, dtype=np.int32) + sample_set_sizes = np.zeros(len(sample_sets), dtype=np.uint64) + + for i, sample in enumerate(ts.samples()): + sample_index_map[sample] = i + + for k, sample_set in enumerate(sample_sets): + sample_set_sizes[k] = len(sample_set) + for sample in sample_set: + sample_index = sample_index_map[sample] + if sample_index == tskit.NULL: + raise ValueError(f"Sample out of bounds: {sample}") + if sample_sets_bits.contains(k, sample_index): + raise ValueError(f"Duplicate sample detected: {sample}") + sample_sets_bits.add(k, sample_index) + + return sample_set_sizes, sample_sets_bits + + +def two_locus_count_stat( + ts, + summary_func, + norm_func, + polarised, + sites=None, + sample_sets=None, +): + """Outer wrapper for two site general stat functionality. Perform some input + validation, get the site index and allele state, then compute the LD matrix. + + TODO: implement mode switching for branch stats + + :param ts: Tree sequence to gather data from. + :param summary_func: Function used to compute each two-locus statistic. + :param norm_func: Function used to generate the normalization coefficients + for each statistic. + :param polarised: If true, skip the computation of the statistic for the + ancestral state. + :param sites: List of two lists containing [row_sites, column_sites]. + :param sample_sets: List of lists of samples to compute stats for. We will + only consider these samples in our computations, + resulting in stats that are computed on subsets of the + samples on the tree sequence. + :returns: 3d numpy array containing LD for (sample_set,row_site,column_site) + unless one or no sample sets are specified, then 2d array + containing LD for (row_site,column_site). 
+ """ + if sample_sets is None: + sample_sets = [ts.samples()] + if sites is None: + sites = [np.arange(ts.num_sites), np.arange(ts.num_sites)] + else: + if len(sites) != 2: + raise ValueError( + f"Sites must be a length 2 list, got a length {len(sites)} list" + ) + sites[0] = np.asarray(sites[0]) + sites[1] = np.asarray(sites[1]) + + row_sites, col_sites = sites + check_sites(row_sites, ts.num_sites) + check_sites(col_sites, ts.num_sites) + + ss_sizes, ss_bits = sample_sets_to_bit_array(ts, sample_sets) + + result = two_site_count_stat( + ts, + summary_func, + norm_func, + len(ss_sizes), + ss_sizes, + ss_bits, + sites[0], + sites[1], + polarised, + ) + + # If there is one sample set, return a 2d numpy array of row/site LD + if len(sample_sets) == 1: + return result.reshape(result.shape[1:3]) + return result + + +def r2_summary_func( + state_dim: int, state: np.ndarray, result: np.ndarray, params: Dict[str, Any] +) -> None: + """Summary function for the r2 statistic. We first compute the proportion of + AB, A, and B haplotypes, then we compute the r2 statistic, storing the outputs + in the result vector, one entry per sample set. + + :param state_dim: Number of sample sets. + :param state: Counts of 3 haplotype configurations for each sample set. + :param result: Vector of length state_dim to store the results in. + :param params: Parameters for the summary function. 
+ """ + sample_set_sizes = params["sample_set_sizes"] + for k in range(state_dim): + n = sample_set_sizes[k] + p_AB = state[0, k] / n + p_Ab = state[1, k] / n + p_aB = state[2, k] / n + + p_A = p_AB + p_Ab + p_B = p_AB + p_aB + + D = p_AB - (p_A * p_B) + denom = p_A * p_B * (1 - p_A) * (1 - p_B) + + if denom == 0 and D == 0: + result[k] = 0 + else: + result[k] = (D * D) / denom + + +def get_paper_ex_ts(): + """Generate the tree sequence example from the tskit paper + + Data taken from the tests: + https://github.com/tskit-dev/tskit/blob/61a844a/c/tests/testlib.c#L55-L96 + + :returns: Tree sequence + """ + nodes = """\ + is_sample time population individual + 1 0 -1 0 + 1 0 -1 0 + 1 0 -1 1 + 1 0 -1 1 + 0 0.071 -1 -1 + 0 0.090 -1 -1 + 0 0.170 -1 -1 + 0 0.202 -1 -1 + 0 0.253 -1 -1 + """ + + edges = """\ + left right parent child + 2 10 4 2 + 2 10 4 3 + 0 10 5 1 + 0 2 5 3 + 2 10 5 4 + 0 7 6 0,5 + 7 10 7 0,5 + 0 2 8 2,6 + """ + + sites = """\ + position ancestral_state + 1 0 + 4.5 0 + 8.5 0 + """ + + mutations = """\ + site node derived_state + 0 2 1 + 1 0 1 + 2 5 1 + """ + + individuals = """\ + flags location parents + 0 0.2,1.5 -1,-1 + 0 0.0,0.0 -1,-1 + """ + + return tskit.load_text( + nodes=io.StringIO(nodes), + edges=io.StringIO(edges), + sites=io.StringIO(sites), + individuals=io.StringIO(individuals), + mutations=io.StringIO(mutations), + strict=False, + ) + + +# fmt:off +# true r2 values for the tree sequence from the tskit paper +PAPER_EX_TRUTH_MATRIX = np.array( + [[1.0, 0.11111111, 0.11111111], # noqa: E241 + [0.11111111, 1.0, 1.0], # noqa: E241 + [0.11111111, 1.0, 1.0]] # noqa: E241 +) +# fmt:on + + +def get_all_site_partitions(n): + """Generate all partitions for square matricies, then combine with replacement + and return all possible pairs of all partitions. + + TODO: only works for square matricies, would need to generate two lists of + partitions to get around this + + :param n: length of one dimension of the !square! matrix. 
+ :returns: combinations of partitions. + """ + parts = [] + for part in tskit.combinatorics.rule_asc(3): + for g in set(permutations(part, len(part))): + p = [] + i = iter(range(n)) + for item in g: + p.append([next(i) for _ in range(item)]) + parts.append(p) + combos = [] + for a, b in combinations_with_replacement({tuple(j) for i in parts for j in i}, 2): + combos.append((a, b)) + combos.append((b, a)) + combos = [[list(a), list(b)] for a, b in set(combos)] + return combos + + +def assert_slice_allclose(a, b): + """Provide two lists of sites to the general stat function, then check to + see if the subset matches the slice out of the truth matrix. Raise if + arrays not close. + + :param a: row sites. + :param b: column sites. + """ + ts = get_paper_ex_ts() + np.testing.assert_allclose( + two_locus_count_stat( + ts, r2_summary_func, norm_hap_weighted, False, sites=[a, b] + ), + PAPER_EX_TRUTH_MATRIX[a[0] : a[-1] + 1, b[0] : b[-1] + 1], + ) + + +@pytest.mark.parametrize( + # Generate all partitions of the LD matrix, then pass each into test_subset + "partition", + get_all_site_partitions(len(PAPER_EX_TRUTH_MATRIX)), +) +def test_subset(partition): + """Given a partition of the truth matrix, check that we can successfully + compute the LD matrix for that given partition, effectively ensuring that + our handling of site subsets is correct. + + :param partition: length 2 list of [row_sites, column_sites]. This is a + pytest fixture for a parametrized function. + """ + a, b = partition + print(a, b) + assert_slice_allclose(a, b)