diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c index c4221ee13e..8080b884a8 100644 --- a/c/tests/test_stats.c +++ b/c/tests/test_stats.c @@ -319,9 +319,9 @@ verify_pair_coalescence_counts(tsk_treeseq_t *ts, tsk_flags_t options) tsk_id_t sample_sets[n]; tsk_size_t sample_set_sizes[P]; tsk_id_t index_tuples[2 * I]; - tsk_id_t node_output_map[N]; + tsk_id_t node_bin_map[N]; tsk_size_t dim = T * N * I; - double C1[dim]; + double C[dim]; tsk_size_t i, j, k; for (i = 0; i < n; i++) { @@ -343,53 +343,67 @@ verify_pair_coalescence_counts(tsk_treeseq_t *ts, tsk_flags_t options) } } + /* test various bin assignments */ for (i = 0; i < N; i++) { - node_output_map[i] = (tsk_id_t) i; + node_bin_map[i] = ((tsk_id_t) i) % 8; } + ret = tsk_treeseq_pair_coalescence_counts(ts, P, sample_set_sizes, sample_sets, I, + index_tuples, T, breakpoints, 8, node_bin_map, options, C); + CU_ASSERT_EQUAL_FATAL(ret, 0); - ret = tsk_treeseq_pair_coalescence_stat(ts, P, sample_set_sizes, sample_sets, I, - index_tuples, T, breakpoints, N, node_output_map, options, C1); + for (i = 0; i < N; i++) { + node_bin_map[i] = i < N / 2 ? 
((tsk_id_t) i) : TSK_NULL; + } + ret = tsk_treeseq_pair_coalescence_counts(ts, P, sample_set_sizes, sample_sets, I, + index_tuples, T, breakpoints, N / 2, node_bin_map, options, C); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + for (i = 0; i < N; i++) { + node_bin_map[i] = (tsk_id_t) i; + } + ret = tsk_treeseq_pair_coalescence_counts(ts, P, sample_set_sizes, sample_sets, I, + index_tuples, T, breakpoints, N, node_bin_map, options, C); CU_ASSERT_EQUAL_FATAL(ret, 0); - /* TODO: test against naive pairs per node per tree */ + /* TODO: compare against naive pairs per node per tree */ /* cover errors */ double bad_breakpoints[2] = { breakpoints[1], 0.0 }; - ret = tsk_treeseq_pair_coalescence_stat(ts, P, sample_set_sizes, sample_sets, I, - index_tuples, 1, bad_breakpoints, N, node_output_map, options, C1); + ret = tsk_treeseq_pair_coalescence_counts(ts, P, sample_set_sizes, sample_sets, I, + index_tuples, 1, bad_breakpoints, N, node_bin_map, options, C); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); index_tuples[0] = (tsk_id_t) P; - ret = tsk_treeseq_pair_coalescence_stat(ts, P, sample_set_sizes, sample_sets, I, - index_tuples, 1, breakpoints, N, node_output_map, options, C1); + ret = tsk_treeseq_pair_coalescence_counts(ts, P, sample_set_sizes, sample_sets, I, + index_tuples, 1, breakpoints, N, node_bin_map, options, C); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_SAMPLE_SET_INDEX); index_tuples[0] = 0; tsk_size_t tmp = sample_set_sizes[0]; sample_set_sizes[0] = 0; - ret = tsk_treeseq_pair_coalescence_stat(ts, P, sample_set_sizes, sample_sets, I, - index_tuples, 1, breakpoints, N, node_output_map, options, C1); + ret = tsk_treeseq_pair_coalescence_counts(ts, P, sample_set_sizes, sample_sets, I, + index_tuples, 1, breakpoints, N, node_bin_map, options, C); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_EMPTY_SAMPLE_SET); sample_set_sizes[0] = tmp; sample_sets[1] = 0; - ret = tsk_treeseq_pair_coalescence_stat(ts, P, sample_set_sizes, sample_sets, I, - index_tuples, 1, breakpoints, N, 
node_output_map, options, C1); + ret = tsk_treeseq_pair_coalescence_counts(ts, P, sample_set_sizes, sample_sets, I, + index_tuples, 1, breakpoints, N, node_bin_map, options, C); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_DUPLICATE_SAMPLE); sample_sets[1] = 1; - ret = tsk_treeseq_pair_coalescence_stat(ts, P, sample_set_sizes, sample_sets, I, - index_tuples, 1, breakpoints, N - 1, node_output_map, options, C1); - CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_NODE_OUTPUT_MAP_DIM); + ret = tsk_treeseq_pair_coalescence_counts(ts, P, sample_set_sizes, sample_sets, I, + index_tuples, 1, breakpoints, N - 1, node_bin_map, options, C); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_NODE_BIN_MAP_DIM); - ret = tsk_treeseq_pair_coalescence_stat(ts, P, sample_set_sizes, sample_sets, I, - index_tuples, 1, breakpoints, 0, node_output_map, options, C1); - CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_NODE_OUTPUT_MAP_DIM); + ret = tsk_treeseq_pair_coalescence_counts(ts, P, sample_set_sizes, sample_sets, I, + index_tuples, 1, breakpoints, 0, node_bin_map, options, C); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_NODE_BIN_MAP_DIM); - node_output_map[0] = -2; - ret = tsk_treeseq_pair_coalescence_stat(ts, P, sample_set_sizes, sample_sets, I, - index_tuples, 1, breakpoints, N, node_output_map, options, C1); - CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_NODE_OUTPUT_MAP); - node_output_map[0] = 0; + node_bin_map[0] = -2; + ret = tsk_treeseq_pair_coalescence_counts(ts, P, sample_set_sizes, sample_sets, I, + index_tuples, 1, breakpoints, N, node_bin_map, options, C); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_NODE_BIN_MAP); + node_bin_map[0] = 0; } typedef struct { @@ -2880,8 +2894,6 @@ test_pair_coalescence_counts(void) tsk_treeseq_t ts; tsk_treeseq_from_text(&ts, 100, nonbinary_ex_nodes, nonbinary_ex_edges, NULL, nonbinary_ex_sites, nonbinary_ex_mutations, NULL, NULL, 0); - verify_pair_coalescence_counts(&ts, TSK_STAT_NODE); - verify_pair_coalescence_counts(&ts, TSK_STAT_NODE | TSK_STAT_SPAN_NORMALISE); 
verify_pair_coalescence_counts(&ts, 0); verify_pair_coalescence_counts(&ts, TSK_STAT_SPAN_NORMALISE); tsk_treeseq_free(&ts); diff --git a/c/tskit/core.c b/c/tskit/core.c index 1e989f015c..9ec99850f4 100644 --- a/c/tskit/core.c +++ b/c/tskit/core.c @@ -484,13 +484,15 @@ tsk_strerror_internal(int err) ret = "Insufficient weights provided (at least 1 required). " "(TSK_ERR_INSUFFICIENT_WEIGHTS)"; break; - case TSK_ERR_BAD_NODE_OUTPUT_MAP: - ret = "Node output map contains values less than TSK_NULL. " - "(TSK_ERR_BAD_NODE_OUTPUT_MAP)"; + + /* Pair coalescence errors */ + case TSK_ERR_BAD_NODE_BIN_MAP: + ret = "Node-to-bin map contains values less than TSK_NULL. " + "(TSK_ERR_BAD_NODE_BIN_MAP)"; break; - case TSK_ERR_BAD_NODE_OUTPUT_MAP_DIM: - ret = "Maximum index in node output map does not match " - "output dimension. (TSK_ERR_BAD_NODE_OUTPUT_MAP_DIM)"; + case TSK_ERR_BAD_NODE_BIN_MAP_DIM: + ret = "Maximum index in node-to-bin map does not match " + "output dimension. (TSK_ERR_BAD_NODE_BIN_MAP_DIM)"; break; /* Mutation mapping errors */ diff --git a/c/tskit/core.h b/c/tskit/core.h index 795d7ab4e6..ec1bbc818c 100644 --- a/c/tskit/core.h +++ b/c/tskit/core.h @@ -695,13 +695,13 @@ Insufficient weights were provided. */ #define TSK_ERR_INSUFFICIENT_WEIGHTS -913 /** -The node output map contains a value less than TSK_NULL. +The node bin map contains a value less than TSK_NULL. */ -#define TSK_ERR_BAD_NODE_OUTPUT_MAP -914 +#define TSK_ERR_BAD_NODE_BIN_MAP -914 /** -Maximum index in node output map does not match output dimension. +Maximum index in node bin map does not match output dimension. 
*/ -#define TSK_ERR_BAD_NODE_OUTPUT_MAP_DIM -915 +#define TSK_ERR_BAD_NODE_BIN_MAP_DIM -915 /** @} */ /** diff --git a/c/tskit/trees.c b/c/tskit/trees.c index e101f10a1a..eefd672656 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -8151,8 +8151,8 @@ tsk_treeseq_extend_edges(const tsk_treeseq_t *self, int max_iter, * ======================================================== */ static int -check_node_output_map(const tsk_size_t num_nodes, const tsk_size_t num_outputs, - const tsk_id_t *node_output_map) +check_node_bin_map( + const tsk_size_t num_nodes, const tsk_size_t num_bins, const tsk_id_t *node_bin_map) { int ret = 0; tsk_id_t max_index, index; @@ -8160,23 +8160,34 @@ check_node_output_map(const tsk_size_t num_nodes, const tsk_size_t num_outputs, max_index = TSK_NULL; for (i = 0; i < num_nodes; i++) { - index = node_output_map[i]; + index = node_bin_map[i]; if (index < TSK_NULL) { - ret = TSK_ERR_BAD_NODE_OUTPUT_MAP; + ret = TSK_ERR_BAD_NODE_BIN_MAP; goto out; } if (index > max_index) { max_index = index; } } - if (num_outputs < 1 || (tsk_id_t) num_outputs != max_index + 1) { - ret = TSK_ERR_BAD_NODE_OUTPUT_MAP_DIM; + if (num_bins < 1 || (tsk_id_t) num_bins != max_index + 1) { + ret = TSK_ERR_BAD_NODE_BIN_MAP_DIM; goto out; } out: return ret; } +static inline void +TRANSPOSE_2D(tsk_size_t rows, tsk_size_t cols, const double *source, double *dest) +{ + tsk_size_t i, j; + for (i = 0; i < rows; ++i) { + for (j = 0; j < cols; ++j) { + dest[j * rows + i] = source[i * cols + j]; + } + } +} + static inline void pair_coalescence_count(tsk_size_t num_set_indexes, const tsk_id_t *set_indexes, tsk_size_t num_sample_sets, const double *parent_count, const double *child_count, @@ -8201,37 +8212,44 @@ int tsk_treeseq_pair_coalescence_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_set_indexes, const tsk_id_t *set_indexes, tsk_size_t num_windows, - const double *windows, tsk_size_t 
num_outputs, const tsk_id_t *node_output_map, - tsk_flags_t options, double *result) + const double *windows, tsk_size_t num_bins, const tsk_id_t *node_bin_map, + pair_coalescence_stat_func_t *summary_func, tsk_size_t summary_func_dim, + void *summary_func_args, tsk_flags_t options, double *result) { int ret = 0; - double left, right, remaining_span, window_span; + double left, right, remaining_span, window_span, x, t; tsk_id_t e, p, c, u, v, w, i, j; tsk_size_t num_samples; tsk_tree_position_t tree_pos; const tsk_table_collection_t *tables = self->tables; const tsk_size_t num_nodes = tables->nodes.num_rows; + const double *restrict nodes_time = self->tables->nodes.time; const double sequence_length = tables->sequence_length; + const tsk_size_t num_outputs = summary_func_dim; /* buffers */ bool *visited = NULL; tsk_id_t *nodes_sample_set = NULL; tsk_id_t *nodes_parent = NULL; + double *coalescing_pairs = NULL; + double *coalescence_time = NULL; double *nodes_sample = NULL; double *sample_count = NULL; - double *coalescing_pairs = NULL; - double *nodes_weight = NULL; + double *bin_weight = NULL; + double *bin_values = NULL; + double *pair_count = NULL; double *outside = NULL; - double *count = NULL; /* row pointers */ double *inside = NULL; double *weight = NULL; + double *values = NULL; double *output = NULL; double *above = NULL; double *below = NULL; double *state = NULL; double *pairs = NULL; + double *times = NULL; tsk_memset(&tree_pos, 0, sizeof(tree_pos)); @@ -8249,7 +8267,7 @@ tsk_treeseq_pair_coalescence_stat(const tsk_treeseq_t *self, tsk_size_t num_samp if (ret != 0) { goto out; } - ret = check_node_output_map(num_nodes, num_outputs, node_output_map); + ret = check_node_bin_map(num_nodes, num_bins, node_bin_map); if (ret != 0) { goto out; } @@ -8266,22 +8284,25 @@ tsk_treeseq_pair_coalescence_stat(const tsk_treeseq_t *self, tsk_size_t num_samp goto out; } - /* initialize internal state */ visited = tsk_malloc(num_nodes * sizeof(*visited)); + outside = 
tsk_malloc(num_sample_sets * sizeof(*outside)); nodes_parent = tsk_malloc(num_nodes * sizeof(*nodes_parent)); nodes_sample = tsk_calloc(num_nodes * num_sample_sets, sizeof(*nodes_sample)); sample_count = tsk_malloc(num_nodes * num_sample_sets * sizeof(*sample_count)); - outside = tsk_malloc(num_sample_sets * sizeof(*outside)); - count = tsk_malloc(num_set_indexes * sizeof(*count)); - coalescing_pairs - = tsk_calloc(num_outputs * num_set_indexes, sizeof(*coalescing_pairs)); - nodes_weight = tsk_malloc(num_outputs * num_set_indexes * sizeof(*nodes_weight)); + coalescing_pairs = tsk_calloc(num_bins * num_set_indexes, sizeof(*coalescing_pairs)); + coalescence_time = tsk_calloc(num_bins * num_set_indexes, sizeof(*coalescence_time)); + bin_weight = tsk_malloc(num_bins * num_set_indexes * sizeof(*bin_weight)); + bin_values = tsk_malloc(num_bins * num_set_indexes * sizeof(*bin_values)); + pair_count = tsk_malloc(num_set_indexes * sizeof(*pair_count)); if (nodes_parent == NULL || nodes_sample == NULL || sample_count == NULL - || coalescing_pairs == NULL || nodes_weight == NULL || outside == NULL - || count == NULL || visited == NULL) { + || coalescing_pairs == NULL || coalescence_time == NULL || bin_weight == NULL + || bin_values == NULL || outside == NULL || pair_count == NULL + || visited == NULL) { ret = TSK_ERR_NO_MEMORY; goto out; } + + /* initialize internal state */ for (c = 0; c < (tsk_id_t) num_nodes; c++) { i = nodes_sample_set[c]; if (i != TSK_NULL) { @@ -8316,24 +8337,28 @@ tsk_treeseq_pair_coalescence_stat(const tsk_treeseq_t *self, tsk_size_t num_samp c = tables->edges.child[e]; nodes_parent[c] = TSK_NULL; inside = GET_2D_ROW(sample_count, num_sample_sets, c); - while (p != TSK_NULL) { - v = node_output_map[p]; + while (p != TSK_NULL) { /* downdate statistic */ + v = node_bin_map[p]; + t = nodes_time[p]; if (v != TSK_NULL) { above = GET_2D_ROW(sample_count, num_sample_sets, p); below = GET_2D_ROW(sample_count, num_sample_sets, c); state = GET_2D_ROW(nodes_sample, num_sample_sets, p); pairs 
= GET_2D_ROW(coalescing_pairs, num_set_indexes, v); + times = GET_2D_ROW(coalescence_time, num_set_indexes, v); pair_coalescence_count(num_set_indexes, set_indexes, num_sample_sets, - above, below, state, inside, outside, count); + above, below, state, inside, outside, pair_count); for (i = 0; i < (tsk_id_t) num_set_indexes; i++) { - pairs[i] -= count[i] * remaining_span; + x = pair_count[i] * remaining_span; + pairs[i] -= x; + times[i] -= t * x; } } c = p; p = nodes_parent[c]; } p = tables->edges.parent[e]; - while (p != TSK_NULL) { + while (p != TSK_NULL) { /* downdate state */ above = GET_2D_ROW(sample_count, num_sample_sets, p); for (i = 0; i < (tsk_id_t) num_sample_sets; i++) { above[i] -= inside[i]; @@ -8348,7 +8373,7 @@ tsk_treeseq_pair_coalescence_stat(const tsk_treeseq_t *self, tsk_size_t num_samp c = tables->edges.child[e]; nodes_parent[c] = p; inside = GET_2D_ROW(sample_count, num_sample_sets, c); - while (p != TSK_NULL) { + while (p != TSK_NULL) { /* update state */ above = GET_2D_ROW(sample_count, num_sample_sets, p); for (i = 0; i < (tsk_id_t) num_sample_sets; i++) { above[i] += inside[i]; @@ -8356,17 +8381,21 @@ tsk_treeseq_pair_coalescence_stat(const tsk_treeseq_t *self, tsk_size_t num_samp p = nodes_parent[p]; } p = tables->edges.parent[e]; - while (p != TSK_NULL) { - v = node_output_map[p]; + while (p != TSK_NULL) { /* update statistic */ + v = node_bin_map[p]; + t = nodes_time[p]; if (v != TSK_NULL) { above = GET_2D_ROW(sample_count, num_sample_sets, p); below = GET_2D_ROW(sample_count, num_sample_sets, c); state = GET_2D_ROW(nodes_sample, num_sample_sets, p); pairs = GET_2D_ROW(coalescing_pairs, num_set_indexes, v); + times = GET_2D_ROW(coalescence_time, num_set_indexes, v); pair_coalescence_count(num_set_indexes, set_indexes, num_sample_sets, - above, below, state, inside, outside, count); + above, below, state, inside, outside, pair_count); for (i = 0; i < (tsk_id_t) num_set_indexes; i++) { - pairs[i] += count[i] * remaining_span; + x = pair_count[i] * remaining_span; + pairs[i] += x; + times[i] += t * 
x; } } c = p; @@ -8376,27 +8405,36 @@ tsk_treeseq_pair_coalescence_stat(const tsk_treeseq_t *self, tsk_size_t num_samp /* flush windows */ while (w < (tsk_id_t) num_windows && windows[w + 1] <= right) { - remaining_span = sequence_length - windows[w + 1]; - tsk_memcpy(nodes_weight, coalescing_pairs, - num_outputs * num_set_indexes * sizeof(*nodes_weight)); + TRANSPOSE_2D(num_bins, num_set_indexes, coalescing_pairs, bin_weight); + TRANSPOSE_2D(num_bins, num_set_indexes, coalescence_time, bin_values); tsk_memset(coalescing_pairs, 0, - num_outputs * num_set_indexes * sizeof(*coalescing_pairs)); - for (j = 0; j < (tsk_id_t) num_samples; j++) { /* traverse subtree */ + num_bins * num_set_indexes * sizeof(*coalescing_pairs)); + tsk_memset(coalescence_time, 0, + num_bins * num_set_indexes * sizeof(*coalescence_time)); + remaining_span = sequence_length - windows[w + 1]; + for (j = 0; j < (tsk_id_t) num_samples; j++) { /* truncate at tree */ c = sample_sets[j]; p = nodes_parent[c]; while (!visited[c] && p != TSK_NULL) { - v = node_output_map[p]; + v = node_bin_map[p]; + t = nodes_time[p]; if (v != TSK_NULL) { above = GET_2D_ROW(sample_count, num_sample_sets, p); below = GET_2D_ROW(sample_count, num_sample_sets, c); state = GET_2D_ROW(nodes_sample, num_sample_sets, p); pairs = GET_2D_ROW(coalescing_pairs, num_set_indexes, v); - weight = GET_2D_ROW(nodes_weight, num_set_indexes, v); + times = GET_2D_ROW(coalescence_time, num_set_indexes, v); pair_coalescence_count(num_set_indexes, set_indexes, - num_sample_sets, above, below, state, below, outside, count); + num_sample_sets, above, below, state, below, outside, + pair_count); for (i = 0; i < (tsk_id_t) num_set_indexes; i++) { - pairs[i] += count[i] * remaining_span / 2; - weight[i] -= count[i] * remaining_span / 2; + weight = GET_2D_ROW(bin_weight, num_bins, i); + values = GET_2D_ROW(bin_values, num_bins, i); + x = pair_count[i] * remaining_span / 2; + pairs[i] += x; + times[i] += t * x; + weight[v] -= x; + values[v] -= t * 
x; } } visited[c] = true; @@ -8404,7 +8442,7 @@ tsk_treeseq_pair_coalescence_stat(const tsk_treeseq_t *self, tsk_size_t num_samp p = nodes_parent[c]; } } - for (j = 0; j < (tsk_id_t) num_samples; j++) { + for (j = 0; j < (tsk_id_t) num_samples; j++) { /* reset tree */ c = sample_sets[j]; p = nodes_parent[c]; while (visited[c] && p != TSK_NULL) { @@ -8413,40 +8451,70 @@ tsk_treeseq_pair_coalescence_stat(const tsk_treeseq_t *self, tsk_size_t num_samp p = nodes_parent[c]; } } - if (options & TSK_STAT_SPAN_NORMALISE) { + for (i = 0; i < (tsk_id_t) num_set_indexes; i++) { /* normalise values */ + weight = GET_2D_ROW(bin_weight, num_bins, i); + values = GET_2D_ROW(bin_values, num_bins, i); + for (v = 0; v < (tsk_id_t) num_bins; v++) { + values[v] /= weight[v]; + } + } + if (options & TSK_STAT_SPAN_NORMALISE) { /* normalise weights */ window_span = windows[w + 1] - windows[w]; - for (v = 0; v < (tsk_id_t) num_outputs; v++) { - weight = GET_2D_ROW(nodes_weight, num_set_indexes, v); - for (i = 0; i < (tsk_id_t) num_set_indexes; i++) { - weight[i] /= window_span; + for (i = 0; i < (tsk_id_t) num_set_indexes; i++) { + weight = GET_2D_ROW(bin_weight, num_bins, i); + for (v = 0; v < (tsk_id_t) num_bins; v++) { + weight[v] /= window_span; } } } - // TODO: - // for (i = 0; i < (tsk_id_t) num_set_indexes; i++) { - // reduce(i, col, nodes_weight, &result[w * row_dim * col_dim]) - // }; - for (v = 0; v < (tsk_id_t) num_outputs; v++) { + for (i = 0; i < (tsk_id_t) num_set_indexes; i++) { /* summarise bins */ + weight = GET_2D_ROW(bin_weight, num_bins, i); + values = GET_2D_ROW(bin_values, num_bins, i); output = GET_3D_ROW( - result, num_outputs, num_set_indexes, (tsk_size_t) w, v); - weight = GET_2D_ROW(nodes_weight, num_set_indexes, v); - for (i = 0; i < (tsk_id_t) num_set_indexes; i++) { - output[i] = weight[i]; + result, num_set_indexes, num_outputs, (tsk_size_t) w, i); + ret = summary_func( + num_bins, weight, values, num_outputs, output, summary_func_args); + if (ret != 0) { + 
goto out; } } w += 1; } } out: tsk_tree_position_free(&tree_pos); tsk_safe_free(nodes_sample_set); + tsk_safe_free(coalescing_pairs); + tsk_safe_free(coalescence_time); tsk_safe_free(nodes_parent); tsk_safe_free(nodes_sample); tsk_safe_free(sample_count); - tsk_safe_free(coalescing_pairs); - tsk_safe_free(nodes_weight); + tsk_safe_free(bin_weight); + tsk_safe_free(bin_values); + tsk_safe_free(pair_count); tsk_safe_free(visited); tsk_safe_free(outside); - tsk_safe_free(count); return ret; } + +static int +pair_coalescence_weights(tsk_size_t TSK_UNUSED(input_dim), const double *weight, + const double *TSK_UNUSED(values), tsk_size_t output_dim, double *output, + void *TSK_UNUSED(params)) +{ + int ret = 0; + tsk_memcpy(output, weight, output_dim * sizeof(*output)); + return ret; +} + +int +tsk_treeseq_pair_coalescence_counts(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t num_set_indexes, const tsk_id_t *set_indexes, + tsk_size_t num_windows, const double *windows, tsk_size_t num_bins, + const tsk_id_t *node_bin_map, tsk_flags_t options, double *result) +{ + return tsk_treeseq_pair_coalescence_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_set_indexes, set_indexes, num_windows, windows, num_bins, + node_bin_map, pair_coalescence_weights, num_bins, NULL, options, result); +} diff --git a/c/tskit/trees.h b/c/tskit/trees.h index ed127797ce..697ae075ba 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1124,11 +1124,23 @@ int tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samp tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); /* Coalescence rates */ +/* Summary callback, run once per window and index pair: `weight` holds the pair + * count and `values` the pair-weighted mean node time for each of `input_dim` bins. + * Fills `result_dim` entries of `result`; a nonzero return is propagated as the error. */ +typedef int pair_coalescence_stat_func_t(tsk_size_t input_dim, const double *weight, + const double *values, tsk_size_t result_dim, double *result, void *params); int tsk_treeseq_pair_coalescence_stat(const tsk_treeseq_t *self, tsk_size_t 
num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_set_indexes, const tsk_id_t *set_indexes, - tsk_size_t num_windows, const double *windows, tsk_size_t num_outputs, - const tsk_id_t *node_output_map, tsk_flags_t options, double *result); + tsk_size_t num_windows, const double *windows, tsk_size_t num_bins, + const tsk_id_t *node_bin_map, pair_coalescence_stat_func_t *summary_func, + tsk_size_t summary_func_dim, void *summary_func_args, tsk_flags_t options, + double *result); +int tsk_treeseq_pair_coalescence_counts(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t num_set_indexes, const tsk_id_t *set_indexes, + tsk_size_t num_windows, const double *windows, tsk_size_t num_bins, + const tsk_id_t *node_bin_map, tsk_flags_t options, double *result); /****************************************************************************/ /* Tree */ diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 59e1e76925..e7314327ff 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -9861,30 +9861,30 @@ TreeSequence_divergence_matrix(TreeSequence *self, PyObject *args, PyObject *kwd } static int -parse_node_output_map(PyObject *node_output_map, PyArrayObject **ret_array, - tsk_size_t *ret_num_outputs, tsk_size_t num_nodes) +parse_node_bin_map(PyObject *node_bin_map, PyArrayObject **ret_array, + tsk_size_t *ret_num_bins, tsk_size_t num_nodes) { int ret = -1; - npy_int32 num_outputs = 0; - PyArrayObject *node_output_map_array = NULL; + npy_int32 num_bins = 0; + PyArrayObject *node_bin_map_array = NULL; npy_intp *shape; npy_int32 *data; npy_int32 max_index; tsk_size_t i; - node_output_map_array = (PyArrayObject *) PyArray_FROMANY( - node_output_map, NPY_INT32, 1, 1, NPY_ARRAY_IN_ARRAY); - if (node_output_map_array == NULL) { + node_bin_map_array = (PyArrayObject *) PyArray_FROMANY( + node_bin_map, NPY_INT32, 1, 1, NPY_ARRAY_IN_ARRAY); 
+ if (node_bin_map_array == NULL) { goto out; } - shape = PyArray_DIMS(node_output_map_array); + shape = PyArray_DIMS(node_bin_map_array); if ((tsk_size_t) shape[0] != num_nodes) { - PyErr_SetString(PyExc_ValueError, "Node output map must have a value per node"); + PyErr_SetString(PyExc_ValueError, "Node-to-bin map must have a value per node"); goto out; } max_index = TSK_NULL; - data = PyArray_DATA(node_output_map_array); + data = PyArray_DATA(node_bin_map_array); for (i = 0; i < num_nodes; i++) { if (data[i] > max_index) { max_index = data[i]; @@ -9892,14 +9892,14 @@ parse_node_output_map(PyObject *node_output_map, PyArrayObject **ret_array, } if (max_index == TSK_NULL) { PyErr_SetString( - PyExc_ValueError, "Node output map has null values for all nodes"); + PyExc_ValueError, "Node-to-bin map has null values for all nodes"); goto out; } - num_outputs = 1 + max_index; + num_bins = 1 + max_index; ret = 0; out: - *ret_num_outputs = (tsk_size_t) num_outputs; - *ret_array = node_output_map_array; + *ret_num_bins = (tsk_size_t) num_bins; + *ret_array = node_bin_map_array; return ret; } @@ -9937,15 +9937,15 @@ TreeSequence_pair_coalescence_counts(TreeSequence *self, PyObject *args, PyObjec PyObject *ret = NULL; static char *kwlist[] = { "windows", "sample_set_sizes", "sample_sets", "indexes", - "node_output_map", "span_normalise", NULL }; + "node_bin_map", "span_normalise", NULL }; PyObject *py_sample_set_sizes = Py_None; PyObject *py_sample_sets = Py_None; PyObject *py_windows = Py_None; - PyObject *py_node_output_map = Py_None; + PyObject *py_node_bin_map = Py_None; PyObject *py_indexes = Py_None; PyArrayObject *result_array = NULL; PyArrayObject *windows_array = NULL; - PyArrayObject *node_output_map_array = NULL; + PyArrayObject *node_bin_map_array = NULL; PyArrayObject *indexes_array = NULL; PyArrayObject *sample_set_sizes_array = NULL; PyArrayObject *sample_sets_array = NULL; @@ -9954,7 +9954,7 @@ TreeSequence_pair_coalescence_counts(TreeSequence *self, PyObject 
*args, PyObjec tsk_size_t num_indexes = 0; tsk_size_t num_sample_sets = 0; tsk_size_t num_windows = 0; - tsk_size_t num_outputs = 0; + tsk_size_t num_bins = 0; int span_normalise = 0; int err; @@ -9962,7 +9962,7 @@ TreeSequence_pair_coalescence_counts(TreeSequence *self, PyObject *args, PyObjec goto out; } if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOOO|i", kwlist, &py_windows, - &py_sample_set_sizes, &py_sample_sets, &py_indexes, &py_node_output_map, + &py_sample_set_sizes, &py_sample_sets, &py_indexes, &py_node_bin_map, &span_normalise)) { goto out; } @@ -9977,7 +9977,7 @@ TreeSequence_pair_coalescence_counts(TreeSequence *self, PyObject *args, PyObjec if (parse_set_indexes(py_indexes, &indexes_array, &num_indexes, 2) != 0) { goto out; } - if (parse_node_output_map(py_node_output_map, &node_output_map_array, &num_outputs, + if (parse_node_bin_map(py_node_bin_map, &node_bin_map_array, &num_bins, tsk_treeseq_get_num_nodes(self->tree_sequence)) != 0) { goto out; @@ -9987,18 +9987,18 @@ TreeSequence_pair_coalescence_counts(TreeSequence *self, PyObject *args, PyObjec } dims[0] = (npy_intp) num_windows; - dims[1] = (npy_intp) num_outputs; - dims[2] = (npy_intp) num_indexes; + dims[1] = (npy_intp) num_indexes; + dims[2] = (npy_intp) num_bins; result_array = (PyArrayObject *) PyArray_SimpleNew(3, dims, NPY_FLOAT64); if (result_array == NULL) { goto out; } - err = tsk_treeseq_pair_coalescence_stat(self->tree_sequence, num_sample_sets, + err = tsk_treeseq_pair_coalescence_counts(self->tree_sequence, num_sample_sets, PyArray_DATA(sample_set_sizes_array), PyArray_DATA(sample_sets_array), num_indexes, PyArray_DATA(indexes_array), num_windows, - PyArray_DATA(windows_array), num_outputs, PyArray_DATA(node_output_map_array), - options, PyArray_DATA(result_array)); + PyArray_DATA(windows_array), num_bins, PyArray_DATA(node_bin_map_array), options, + PyArray_DATA(result_array)); if (err != 0) { handle_library_error(err); goto out; @@ -10010,7 +10010,7 @@ 
TreeSequence_pair_coalescence_counts(TreeSequence *self, PyObject *args, PyObjec Py_XDECREF(sample_sets_array); Py_XDECREF(windows_array); Py_XDECREF(indexes_array); - Py_XDECREF(node_output_map_array); + Py_XDECREF(node_bin_map_array); Py_XDECREF(result_array); return ret; } diff --git a/python/requirements/CI-complete/requirements.txt b/python/requirements/CI-complete/requirements.txt index 6cb9c3571a..b3cbe9cf62 100644 --- a/python/requirements/CI-complete/requirements.txt +++ b/python/requirements/CI-complete/requirements.txt @@ -1,7 +1,7 @@ biopython==1.83 coverage==7.5.4 dendropy==5.0.1 -h5py==3.9.0 +h5py==3.11.0 kastore==0.3.3 lshmm==0.0.8 msgpack==1.0.8 diff --git a/python/tests/test_coalrate.py b/python/tests/test_coalrate.py index 37ea4b0d28..4cf326401b 100644 --- a/python/tests/test_coalrate.py +++ b/python/tests/test_coalrate.py @@ -103,9 +103,9 @@ def _pair_coalescence_stat( if not (min(s) >= 0 and max(s) < ts.num_nodes): raise ValueError("Sample is out of bounds") - drop_right_dimension = False + drop_middle_dimension = False if indexes is None: - drop_right_dimension = True + drop_middle_dimension = True if len(sample_sets) == 1: indexes = [(0, 0)] elif len(sample_sets) == 2: @@ -255,10 +255,11 @@ def _pair_coalescence_stat( ) w += 1 - if drop_right_dimension: - output = output[..., 0] + output = output.transpose(0, 2, 1) + if drop_middle_dimension: + output = output.squeeze(1) if drop_left_dimension: - output = output[0] + output = output.squeeze(0) return output @@ -293,7 +294,7 @@ def proto_pair_coalescence_counts( events within time intervals (if an array of breakpoints is supplied) rather than for individual nodes (the default). - The output array has dimension `(windows, nodes, indexes)` with + The output array has dimension `(windows, indexes, nodes)` with dimensions dropped when the corresponding argument is set to None. 
:param list sample_sets: A list of lists of Node IDs, specifying the @@ -558,9 +559,9 @@ def test_population_pairs(self): indexes = [(0, 0), (0, 1), (1, 1)] implm = ts.pair_coalescence_counts(sample_sets=[ss0, ss1], indexes=indexes) check = np.full(implm.shape, np.nan) - check[:, 0] = np.array([0.0] * 8 + [0, 0, 1, 5, 0, 0]) - check[:, 1] = np.array([0.0] * 8 + [0, 0, 0, 0, 4, 12]) - check[:, 2] = np.array([0.0] * 8 + [1, 2, 0, 0, 0, 3]) + check[0] = np.array([0.0] * 8 + [0, 0, 1, 5, 0, 0]) + check[1] = np.array([0.0] * 8 + [0, 0, 0, 0, 4, 12]) + check[2] = np.array([0.0] * 8 + [1, 2, 0, 0, 0, 3]) np.testing.assert_allclose(implm, check) # TODO: remove with prototype proto = proto_pair_coalescence_counts( @@ -715,9 +716,9 @@ def test_population_pairs(self): sample_sets=[[0, 1], [2, 3]], indexes=indexes, span_normalise=False ) check = np.empty(implm.shape) - check[:, 0] = np.array([0] * 4 + [0, 0, 0, 1 * L]) - check[:, 1] = np.array([0] * 4 + [1 * (L - S), 1 * (L - S), 2 * S, 2 * L]) - check[:, 2] = np.array([0] * 4 + [0, 1 * L, 0, 0]) + check[0] = np.array([0] * 4 + [0, 0, 0, 1 * L]) + check[1] = np.array([0] * 4 + [1 * (L - S), 1 * (L - S), 2 * S, 2 * L]) + check[2] = np.array([0] * 4 + [0, 1 * L, 0, 0]) np.testing.assert_allclose(implm, check) # TODO: remove with prototype proto = proto_pair_coalescence_counts( @@ -864,13 +865,13 @@ def _check_subset_pairs(ts, windows): implm = ts.pair_coalescence_counts( sample_sets=[ss0, ss1], indexes=idx, windows=windows, span_normalise=False ) - dim = (windows.size - 1, ts.num_nodes, len(idx)) + dim = (windows.size - 1, len(idx), ts.num_nodes) check = np.full(dim, np.nan) for w, (a, b) in enumerate(zip(windows[:-1], windows[1:])): tsw = ts.keep_intervals(np.array([[a, b]]), simplify=False) - check[w, :, 0] = naive_pair_coalescence_counts(tsw, ss0, ss1) - check[w, :, 1] = naive_pair_coalescence_counts(tsw, ss1, ss1) / 2 - check[w, :, 2] = naive_pair_coalescence_counts(tsw, ss0, ss0) / 2 + check[w, 0] = 
naive_pair_coalescence_counts(tsw, ss0, ss1) + check[w, 1] = naive_pair_coalescence_counts(tsw, ss1, ss1) / 2 + check[w, 2] = naive_pair_coalescence_counts(tsw, ss0, ss0) / 2 np.testing.assert_allclose(implm, check) # TODO: remove with prototype proto = proto_pair_coalescence_counts( @@ -1258,12 +1259,12 @@ def test_output_dim(self): sample_sets=ss, windows=windows, indexes=None ) assert implm.shape == (1, ts.num_nodes) - indexes = [(0, 1)] + indexes = [(0, 1), (1, 1)] implm = ts.pair_coalescence_counts( sample_sets=ss, windows=windows, indexes=indexes ) - assert implm.shape == (1, ts.num_nodes, 1) + assert implm.shape == (1, 2, ts.num_nodes) implm = ts.pair_coalescence_counts( sample_sets=ss, windows=None, indexes=indexes ) - assert implm.shape == (ts.num_nodes, 1) + assert implm.shape == (2, ts.num_nodes) diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 2700aa4f2c..8662b669bf 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -4224,7 +4224,7 @@ def pair_coalescence_counts( sample_set_sizes=None, indexes=None, windows=None, - node_output_map=None, + node_bin_map=None, span_normalise=False, ): n = ts.get_num_samples() @@ -4240,25 +4240,34 @@ def pair_coalescence_counts( indexes = [(i, j) for i, j in pairs] if windows is None: windows = np.array([0, 0.5, 1.0]) * ts.get_sequence_length() - if node_output_map is None: - node_output_map = np.arange(N, dtype=np.int32) + if node_bin_map is None: + node_bin_map = np.arange(N, dtype=np.int32) return ts.pair_coalescence_counts( sample_sets=sample_sets, sample_set_sizes=sample_set_sizes, windows=windows, indexes=indexes, - node_output_map=node_output_map, + node_bin_map=node_bin_map, span_normalise=span_normalise, ) def test_output_dims(self): ts = self.example_ts() coal = self.pair_coalescence_counts(ts) - dim = (2, ts.get_num_nodes(), 3) + dim = (2, 3, ts.get_num_nodes()) assert coal.shape == dim coal = self.pair_coalescence_counts(ts, span_normalise=True) 
assert coal.shape == dim + def test_node_shuffle(self): + rng = np.random.default_rng(1024) + ts = self.example_ts() + coal = self.pair_coalescence_counts(ts) + node_bin_map = np.arange(ts.get_num_nodes(), dtype=np.int32) + rng.shuffle(node_bin_map) + coal_shuffle = self.pair_coalescence_counts(ts, node_bin_map=node_bin_map) + np.testing.assert_allclose(coal_shuffle[..., node_bin_map], coal) + @pytest.mark.parametrize("bad_node", [-1, -2, 1000]) def test_c_tsk_err_node_out_of_bounds(self, bad_node): ts = self.example_ts() @@ -4274,12 +4283,12 @@ def test_c_tsk_err_bad_windows(self): with pytest.raises(_tskit.LibraryError, match="BAD_WINDOWS"): self.pair_coalescence_counts(ts, windows=[-1.0, L]) - def test_c_tsk_err_bad_node_output_map(self): + def test_c_tsk_err_bad_node_bin_map(self): ts = self.example_ts() - node_output_map = np.arange(ts.get_num_nodes(), dtype=np.int32) - node_output_map[0] = -10 - with pytest.raises(_tskit.LibraryError, match="BAD_NODE_OUTPUT_MAP"): - self.pair_coalescence_counts(ts, node_output_map=node_output_map) + node_bin_map = np.arange(ts.get_num_nodes(), dtype=np.int32) + node_bin_map[0] = -10 + with pytest.raises(_tskit.LibraryError, match="BAD_NODE_BIN_MAP"): + self.pair_coalescence_counts(ts, node_bin_map=node_bin_map) @pytest.mark.parametrize("bad_index", [-1, 10]) def test_c_tsk_err_bad_sample_set_index(self, bad_index): @@ -4316,13 +4325,13 @@ def test_cpy_bad_indexes(self, indexes): with pytest.raises(ValueError, match="too small depth"): self.pair_coalescence_counts(ts, indexes=np.ravel(indexes)) - def test_cpy_bad_node_output_map(self): + def test_cpy_bad_node_bin_map(self): ts = self.example_ts() num_nodes = ts.get_num_nodes() - node_output_map = np.full(num_nodes, tskit.NULL, dtype=np.int32) + node_bin_map = np.full(num_nodes, tskit.NULL, dtype=np.int32) with pytest.raises(ValueError, match="null values for all nodes"): - self.pair_coalescence_counts(ts, node_output_map=node_output_map) + self.pair_coalescence_counts(ts, 
node_bin_map=node_bin_map) with pytest.raises(ValueError, match="a value per node"): - self.pair_coalescence_counts(ts, node_output_map=node_output_map[:-1]) + self.pair_coalescence_counts(ts, node_bin_map=node_bin_map[:-1]) with pytest.raises(TypeError, match="cast array data"): - self.pair_coalescence_counts(ts, node_output_map=np.zeros(num_nodes)) + self.pair_coalescence_counts(ts, node_bin_map=np.zeros(num_nodes)) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 901e6736d4..4d6b9da347 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -9328,7 +9328,7 @@ def pair_coalescence_counts( events within time intervals (if an array of breakpoints is supplied) rather than for individual nodes (the default). - The output array has dimension `(windows, nodes, indexes)` with + The output array has dimension `(windows, indexes, nodes)` with dimensions dropped when the corresponding argument is set to None. :param list sample_sets: A list of lists of Node IDs, specifying the @@ -9350,9 +9350,9 @@ def pair_coalescence_counts( if not (min(s) >= 0 and max(s) < self.num_nodes): raise ValueError("Sample is out of bounds") - drop_right_dimension = False + drop_middle_dimension = False if indexes is None: - drop_right_dimension = True + drop_middle_dimension = True if len(sample_sets) == 1: indexes = [(0, 0)] elif len(sample_sets) == 2: @@ -9379,7 +9379,7 @@ def pair_coalescence_counts( raise ValueError("Window breaks must be strictly increasing") if isinstance(time_windows, str) and time_windows == "nodes": - node_output_map = np.arange(self.num_nodes, dtype=np.int32) + node_bin_map = np.arange(self.num_nodes, dtype=np.int32) else: if not (isinstance(time_windows, np.ndarray) and time_windows.size > 1): raise ValueError("Time windows must be an array of breakpoints") @@ -9387,9 +9387,9 @@ def pair_coalescence_counts( raise ValueError("Time windows must be strictly increasing") if self.time_units == tskit.TIME_UNITS_UNCALIBRATED: raise ValueError("Time 
windows require calibrated node times") - node_output_map = np.digitize(self.nodes_time, time_windows) - 1 - node_output_map[node_output_map == time_windows.size - 1] = tskit.NULL - node_output_map = node_output_map.astype(np.int32) + node_bin_map = np.digitize(self.nodes_time, time_windows) - 1 + node_bin_map[node_bin_map == time_windows.size - 1] = tskit.NULL + node_bin_map = node_bin_map.astype(np.int32) sample_set_sizes = np.array([len(s) for s in sample_sets], dtype=np.uint32) sample_sets = util.safe_np_int_cast(np.hstack(sample_sets), np.int32) @@ -9399,14 +9399,14 @@ def pair_coalescence_counts( sample_set_sizes=sample_set_sizes, windows=windows, indexes=indexes, - node_output_map=node_output_map, + node_bin_map=node_bin_map, span_normalise=span_normalise, ) - if drop_right_dimension: - coalescing_pairs = coalescing_pairs[..., 0] + if drop_middle_dimension: + coalescing_pairs = np.squeeze(coalescing_pairs, axis=1) if drop_left_dimension: - coalescing_pairs = coalescing_pairs[0] + coalescing_pairs = np.squeeze(coalescing_pairs, axis=0) return coalescing_pairs