diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c index 7dccc2346d..a40817e7a9 100644 --- a/c/tests/test_stats.c +++ b/c/tests/test_stats.c @@ -2017,11 +2017,17 @@ test_empty_genetic_relatedness_vector(void) } ret = tsk_treeseq_genetic_relatedness_vector( - &ts, num_weights, weights, 1, windows, result, 0); + &ts, num_weights, weights, 1, windows, num_samples, ts.samples, result, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_genetic_relatedness_vector(&ts, num_weights, weights, 1, windows, + num_samples, ts.samples, result, TSK_STAT_NONCENTRED); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + windows[0] = 0.5 * tsk_treeseq_get_sequence_length(&ts); + windows[1] = 0.75 * tsk_treeseq_get_sequence_length(&ts); ret = tsk_treeseq_genetic_relatedness_vector( - &ts, num_weights, weights, 1, windows, result, TSK_STAT_NONCENTRED); + &ts, num_weights, weights, 1, windows, num_samples, ts.samples, result, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); tsk_treeseq_free(&ts); @@ -2055,17 +2061,26 @@ verify_genetic_relatedness_vector( } } - ret = tsk_treeseq_genetic_relatedness_vector( - ts, num_weights, weights, num_windows, windows, result, 0); + ret = tsk_treeseq_genetic_relatedness_vector(ts, num_weights, weights, num_windows, + windows, num_samples, ts->samples, result, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); - ret = tsk_treeseq_genetic_relatedness_vector( - ts, num_weights, weights, num_windows, windows, result, TSK_STAT_NONCENTRED); + windows[0] = windows[1] / 2; + if (num_windows > 1) { + windows[num_windows - 1] + = windows[num_windows - 2] + (L / (double) (2 * num_windows)); + } + ret = tsk_treeseq_genetic_relatedness_vector(ts, num_weights, weights, num_windows, + windows, num_samples, ts->samples, result, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_treeseq_genetic_relatedness_vector(ts, num_weights, weights, num_windows, + windows, num_samples, ts->samples, result, TSK_STAT_NONCENTRED); CU_ASSERT_EQUAL_FATAL(ret, 0); tsk_set_debug_stream(_devnull); - ret = tsk_treeseq_genetic_relatedness_vector( - ts, num_weights, weights, num_windows, windows, result, TSK_DEBUG); + ret = tsk_treeseq_genetic_relatedness_vector(ts, num_weights, weights, num_windows, + windows, num_samples, ts->samples, result, TSK_DEBUG); CU_ASSERT_EQUAL_FATAL(ret, 0); tsk_set_debug_stream(stdout); @@ -2102,6 +2117,7 @@ test_paper_ex_genetic_relatedness_vector_errors(void) tsk_size_t num_samples; double *weights, *result; tsk_size_t j; + tsk_size_t num_windows = 2; tsk_size_t num_weights = 2; double windows[] = { 0, 0, 0 }; @@ -2110,7 +2126,7 @@ test_paper_ex_genetic_relatedness_vector_errors(void) num_samples = tsk_treeseq_get_num_samples(&ts); weights = tsk_malloc(num_weights * num_samples * sizeof(double)); - result = tsk_malloc(num_weights * num_samples * sizeof(double)); + result = tsk_malloc(num_windows * num_weights * num_samples * sizeof(double)); for (j = 0; j < num_samples; j++) { weights[j] = 1.0; } @@ -2120,41 +2136,41 @@ test_paper_ex_genetic_relatedness_vector_errors(void) /* Window errors */ ret = tsk_treeseq_genetic_relatedness_vector( - &ts, 1, weights, 0, windows, result, TSK_STAT_BRANCH); + &ts, 1, weights, 0, windows, num_samples, ts.samples, result, TSK_STAT_BRANCH); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_NUM_WINDOWS); ret = tsk_treeseq_genetic_relatedness_vector( - &ts, 1, weights, 0, NULL, result, TSK_STAT_BRANCH); + &ts, 1, weights, 0, NULL, num_samples, ts.samples, result, TSK_STAT_BRANCH); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_NUM_WINDOWS); ret = tsk_treeseq_genetic_relatedness_vector( - &ts, 1, weights, 2, windows, result, TSK_STAT_BRANCH); + &ts, 1, weights, 2, windows, num_samples, ts.samples, result, TSK_STAT_BRANCH); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); windows[0] = -1; ret = tsk_treeseq_genetic_relatedness_vector( - &ts, 1, weights, 2, windows, result, TSK_STAT_BRANCH); + &ts, 1, weights, 2, windows, num_samples, ts.samples, result, TSK_STAT_BRANCH); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); - windows[0] = 10; + windows[0] = 12; ret = tsk_treeseq_genetic_relatedness_vector( - &ts, 1, weights, 2, windows, result, TSK_STAT_BRANCH); + &ts, 1, weights, 2, windows, num_samples, ts.samples, result, TSK_STAT_BRANCH); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); windows[0] = 0; windows[2] = 12; ret = tsk_treeseq_genetic_relatedness_vector( - &ts, 1, weights, 2, windows, result, TSK_STAT_BRANCH); + &ts, 1, weights, 2, windows, num_samples, ts.samples, result, TSK_STAT_BRANCH); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); /* unsupported mode errors */ windows[0] = 0.0; windows[1] = 5.0; windows[2] = 10.0; - ret = tsk_treeseq_genetic_relatedness_vector( - &ts, num_weights, weights, 2, windows, result, TSK_STAT_SITE); + ret = tsk_treeseq_genetic_relatedness_vector(&ts, num_weights, weights, 2, windows, + num_samples, ts.samples, result, TSK_STAT_SITE); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSUPPORTED_STAT_MODE); - ret = tsk_treeseq_genetic_relatedness_vector( - &ts, num_weights, weights, 2, windows, result, TSK_STAT_NODE); + ret = tsk_treeseq_genetic_relatedness_vector(&ts, num_weights, weights, 2, windows, + num_samples, ts.samples, result, TSK_STAT_NODE); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSUPPORTED_STAT_MODE); tsk_treeseq_free(&ts); @@ -2162,6 +2178,51 @@ test_paper_ex_genetic_relatedness_vector_errors(void) free(result); } +static void +test_paper_ex_genetic_relatedness_vector_node_errors(void) +{ + int ret; + tsk_treeseq_t ts; + tsk_size_t num_samples; + double *weights, *result; + tsk_size_t j; + tsk_size_t num_weights = 2; + tsk_size_t num_windows = 2; + double windows[] = { 1, 1.5, 2 }; + tsk_size_t num_nodes = 3; + const tsk_id_t good_nodes[] = { 1, 0, 2 }; + const tsk_id_t bad_nodes1[] = { 1, -1, 2 }; + const tsk_id_t bad_nodes2[] = { 1, 100, 2 }; + + tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, + paper_ex_mutations, paper_ex_individuals, NULL, 0); + num_samples = tsk_treeseq_get_num_samples(&ts); + + weights = tsk_malloc(num_weights * num_samples * sizeof(double)); + result = tsk_malloc(num_windows * num_weights * num_nodes * sizeof(double)); + for (j = 0; j < num_samples; j++) { + weights[j] = 1.0; + } + for (j = 0; j < num_samples; j++) { + weights[j + num_samples] = (float) j; + } + + /* node errors */ + ret = tsk_treeseq_genetic_relatedness_vector(&ts, num_weights, weights, 2, windows, + num_nodes, good_nodes, result, TSK_STAT_BRANCH); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_genetic_relatedness_vector(&ts, num_weights, weights, 2, windows, + num_nodes, bad_nodes1, result, TSK_STAT_BRANCH); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + ret = tsk_treeseq_genetic_relatedness_vector(&ts, num_weights, weights, 2, windows, + num_nodes, bad_nodes2, result, TSK_STAT_BRANCH); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + + tsk_treeseq_free(&ts); + free(weights); + free(result); +} + static void test_paper_ex_Y2_errors(void) { @@ -3708,6 +3769,8 @@ main(int argc, char **argv) test_paper_ex_genetic_relatedness_vector }, { "test_paper_ex_genetic_relatedness_vector_errors", test_paper_ex_genetic_relatedness_vector_errors }, + { "test_paper_ex_genetic_relatedness_vector_node_errors", + test_paper_ex_genetic_relatedness_vector_node_errors }, { "test_paper_ex_Y2_errors", test_paper_ex_Y2_errors }, { "test_paper_ex_Y2", test_paper_ex_Y2 }, { "test_paper_ex_f2_errors", test_paper_ex_f2_errors }, diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 9462dfc51c..a554cfad55 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -5197,6 +5197,7 @@ tsk_tree_position_print_state(const tsk_tree_position_t *self, FILE *out) { fprintf(out, "Tree position state\n"); fprintf(out, "index = %d\n", (int) self->index); + fprintf(out, "interval = [%f,\t%f)\n", self->interval.left, self->interval.right); fprintf( out, "out = start=%d\tstop=%d\n", (int) self->out.start, (int) self->out.stop); fprintf( @@ -9907,10 +9908,12 @@ typedef struct { const double *weights; tsk_size_t num_windows; const double *windows; + tsk_size_t num_focal_nodes; + const tsk_id_t *focal_nodes; tsk_flags_t options; double *result; - /* tree */ - double tree_left; + tsk_tree_position_t tree_pos; + double position; tsk_size_t num_nodes; tsk_id_t *parent; double *x; @@ -9926,7 +9929,10 @@ tsk_matvec_calculator_print_state(const tsk_matvec_calculator_t *self, FILE *out fprintf(out, "Matvec state:\n"); fprintf(out, "options = %d\n", self->options); - fprintf(out, "tree_left = %f\n", self->tree_left); + fprintf(out, "position = %f\n", self->position); + fprintf(out, "focal nodes = %lld: [", (long long) self->num_focal_nodes); + fprintf(out, "tree_pos:\n"); + tsk_tree_position_print_state(&self->tree_pos, out); fprintf(out, "samples = %lld: [", (long long) num_samples); fprintf(out, "]\n"); fprintf(out, "node\tparent\tx\tv\tw"); @@ -9942,7 +9948,8 @@ tsk_matvec_calculator_print_state(const tsk_matvec_calculator_t *self, FILE *out static int tsk_matvec_calculator_init(tsk_matvec_calculator_t *self, const tsk_treeseq_t *ts, tsk_size_t num_weights, const double *weights, tsk_size_t num_windows, - const double *windows, tsk_flags_t options, double *result) + const double *windows, tsk_size_t num_focal_nodes, const tsk_id_t *focal_nodes, + tsk_flags_t options, double *result) { int ret = 0; tsk_size_t num_samples = tsk_treeseq_get_num_samples(ts); @@ -9950,18 +9957,22 @@ tsk_matvec_calculator_init(tsk_matvec_calculator_t *self, const tsk_treeseq_t *t const double *row; double *new_row; tsk_size_t k; - tsk_id_t u, j; + tsk_id_t index, u, j; double *weight_means = tsk_malloc(num_weights * sizeof(*weight_means)); + const tsk_size_t num_trees = ts->num_trees; + const double *restrict breakpoints = ts->breakpoints; self->ts = ts; - self->tree_left = 0.0; self->num_weights = num_weights; self->weights = weights; self->num_windows = num_windows; self->windows = windows; + self->num_focal_nodes = num_focal_nodes; + self->focal_nodes = focal_nodes; self->options = options; self->result = result; self->num_nodes = num_nodes; + self->position = windows[0]; self->parent = tsk_malloc(num_nodes * sizeof(*self->parent)); self->x = tsk_calloc(num_nodes, sizeof(*self->x)); @@ -9974,12 +9985,34 @@ tsk_matvec_calculator_init(tsk_matvec_calculator_t *self, const tsk_treeseq_t *t goto out; } - tsk_memset(result, 0, num_windows * num_samples * num_weights * sizeof(*result)); + tsk_memset(result, 0, num_windows * num_focal_nodes * num_weights * sizeof(*result)); tsk_memset(self->parent, TSK_NULL, num_nodes * sizeof(*self->parent)); + for (j = 0; j < (tsk_id_t) num_focal_nodes; j++) { + if (focal_nodes[j] < 0 || (tsk_size_t) focal_nodes[j] >= num_nodes) { + ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + goto out; + } + } + + ret = tsk_tree_position_init(&self->tree_pos, ts, 0); + if (ret != 0) { + goto out; + } + /* seek to the first window */ + index = (tsk_id_t) tsk_search_sorted(breakpoints, num_trees + 1, windows[0]); + if (breakpoints[index] > windows[0]) { + index--; + } + ret = tsk_tree_position_seek_forward(&self->tree_pos, index); + if (ret != 0) { + goto out; + } + for (k = 0; k < num_weights; k++) { weight_means[k] = 0.0; } + /* centre the input */ if (!(options & TSK_STAT_NONCENTRED)) { for (j = 0; j < (tsk_id_t) num_samples; j++) { row = GET_2D_ROW(weights, num_weights, j); @@ -9992,6 +10025,7 @@ tsk_matvec_calculator_init(tsk_matvec_calculator_t *self, const tsk_treeseq_t *t } } + /* set the initial state */ for (j = 0; j < (tsk_id_t) num_samples; j++) { u = ts->samples[j]; row = GET_2D_ROW(weights, num_weights, j); @@ -10012,6 +10046,7 @@ tsk_matvec_calculator_free(tsk_matvec_calculator_t *self) tsk_safe_free(self->x); tsk_safe_free(self->w); tsk_safe_free(self->v); + tsk_tree_position_free(&self->tree_pos); /* Make this safe for multiple free calls */ memset(self, 0, sizeof(*self)); @@ -10019,7 +10054,7 @@ tsk_matvec_calculator_free(tsk_matvec_calculator_t *self) } static inline void -tsk_matvec_calculator_add_z(tsk_id_t u, tsk_id_t p, const double tree_left, +tsk_matvec_calculator_add_z(tsk_id_t u, tsk_id_t p, const double position, double *restrict x, const tsk_size_t num_weights, double *restrict w, double *restrict v, const double *restrict nodes_time) { @@ -10029,7 +10064,7 @@ tsk_matvec_calculator_add_z(tsk_id_t u, tsk_id_t p, const double tree_left, if (p != TSK_NULL) { t = nodes_time[p] - nodes_time[u]; - span = tree_left - x[u]; + span = position - x[u]; // do this: self->v[u] += t * span * self->w[u]; w_row = GET_2D_ROW(w, num_weights, u); v_row = GET_2D_ROW(v, num_weights, u); @@ -10037,7 +10072,7 @@ tsk_matvec_calculator_add_z(tsk_id_t u, tsk_id_t p, const double tree_left, v_row[j] += t * span * w_row[j]; } } - x[u] = tree_left; + x[u] = position; } static void @@ -10047,7 +10082,7 @@ tsk_matvec_calculator_adjust_path_up( tsk_size_t j; double *p_row, *c_row; const tsk_id_t *restrict parent = self->parent; - const double tree_left = self->tree_left; + const double position = self->position; double *restrict x = self->x; const tsk_size_t num_weights = self->num_weights; double *restrict w = self->w; @@ -10057,7 +10092,7 @@ tsk_matvec_calculator_adjust_path_up( // sign = -1 for removing edges, +1 for adding while (p != TSK_NULL) { tsk_matvec_calculator_add_z( - p, parent[p], tree_left, x, num_weights, w, v, nodes_time); + p, parent[p], position, x, num_weights, w, v, nodes_time); // do this: self->v[c] -= sign * self->v[p]; p_row = GET_2D_ROW(v, num_weights, p); c_row = GET_2D_ROW(v, num_weights, c); @@ -10078,7 +10113,7 @@ static void tsk_matvec_calculator_remove_edge(tsk_matvec_calculator_t *self, tsk_id_t p, tsk_id_t c) { tsk_id_t *parent = self->parent; - const double tree_left = self->tree_left; + const double position = self->position; double *restrict x = self->x; const tsk_size_t num_weights = self->num_weights; double *restrict w = self->w; @@ -10086,7 +10121,7 @@ tsk_matvec_calculator_remove_edge(tsk_matvec_calculator_t *self, tsk_id_t p, tsk const double *restrict nodes_time = self->ts->tables->nodes.time; tsk_matvec_calculator_add_z( - c, parent[c], tree_left, x, num_weights, w, v, nodes_time); + c, parent[c], position, x, num_weights, w, v, nodes_time); parent[c] = TSK_NULL; tsk_matvec_calculator_adjust_path_up(self, p, c, -1); } @@ -10097,7 +10132,7 @@ tsk_matvec_calculator_insert_edge(tsk_matvec_calculator_t *self, tsk_id_t p, tsk tsk_id_t *parent = self->parent; tsk_matvec_calculator_adjust_path_up(self, p, c, +1); - self->x[c] = self->tree_left; + self->x[c] = self->position; parent[c] = p; } @@ -10107,9 +10142,9 @@ tsk_matvec_calculator_write_output(tsk_matvec_calculator_t *self, double *restri int ret = 0; tsk_id_t u; tsk_size_t j, k; - tsk_size_t n = tsk_treeseq_get_num_samples(self->ts); + const tsk_size_t n = self->num_focal_nodes; const tsk_size_t num_weights = self->num_weights; - const double tree_left = self->tree_left; + const double position = self->position; double *u_row, *out_row; double *out_means = tsk_malloc(num_weights * sizeof(*out_means)); const tsk_id_t *restrict parent = self->parent; @@ -10117,7 +10152,7 @@ tsk_matvec_calculator_write_output(tsk_matvec_calculator_t *self, double *restri double *restrict x = self->x; double *restrict w = self->w; double *restrict v = self->v; - const tsk_id_t *restrict samples = self->ts->samples; + const tsk_id_t *restrict focal_nodes = self->focal_nodes; if (out_means == NULL) { ret = TSK_ERR_NO_MEMORY; @@ -10126,11 +10161,11 @@ tsk_matvec_calculator_write_output(tsk_matvec_calculator_t *self, double *restri for (j = 0; j < n; j++) { out_row = GET_2D_ROW(y, num_weights, j); - u = samples[j]; + u = focal_nodes[j]; while (u != TSK_NULL) { - if (x[u] != tree_left) { + if (x[u] != position) { tsk_matvec_calculator_add_z( - u, parent[u], tree_left, x, num_weights, w, v, nodes_time); + u, parent[u], position, x, num_weights, w, v, nodes_time); } u_row = GET_2D_ROW(v, num_weights, u); for (k = 0; k < num_weights; k++) { @@ -10173,50 +10208,65 @@ tsk_matvec_calculator_run(tsk_matvec_calculator_t *self) int ret = 0; tsk_size_t j, k, m; tsk_id_t e, p, c; - tsk_size_t n = tsk_treeseq_get_num_samples(self->ts); - double tree_right; - const double sequence_length = self->ts->tables->sequence_length; + const tsk_size_t out_size = self->num_weights * self->num_focal_nodes; const tsk_size_t num_edges = self->ts->tables->edges.num_rows; - const tsk_id_t *restrict I = self->ts->tables->indexes.edge_insertion_order; - const tsk_id_t *restrict O = self->ts->tables->indexes.edge_removal_order; const double *restrict edge_right = self->ts->tables->edges.right; const double *restrict edge_left = self->ts->tables->edges.left; const tsk_id_t *restrict edge_child = self->ts->tables->edges.child; const tsk_id_t *restrict edge_parent = self->ts->tables->edges.parent; + const double *restrict windows = self->windows; double *restrict out; + tsk_tree_position_t tree_pos = self->tree_pos; + const tsk_id_t *restrict in_order = tree_pos.in.order; + const tsk_id_t *restrict out_order = tree_pos.out.order; + bool valid; + double next_position; - j = 0; - k = 0; m = 0; - tree_right = sequence_length; + self->position = windows[0]; - while ( - m < self->num_windows && k < num_edges && self->tree_left <= sequence_length) { - while (k < num_edges && edge_right[O[k]] == self->tree_left) { - e = O[k]; - p = edge_parent[e]; - c = edge_child[e]; - tsk_matvec_calculator_remove_edge(self, p, c); - k++; - } - while (j < num_edges && edge_left[I[j]] == self->tree_left) { - e = I[j]; + for (j = (tsk_size_t) tree_pos.in.start; j != (tsk_size_t) tree_pos.in.stop; j++) { + e = in_order[j]; + tsk_bug_assert(edge_left[e] <= self->position); + if (self->position < edge_right[e]) { p = edge_parent[e]; c = edge_child[e]; tsk_matvec_calculator_insert_edge(self, p, c); - self->x[c] = self->tree_left; - j++; } - tree_right = self->windows[m + 1]; + } + + valid = tsk_tree_position_next(&tree_pos); + j = (tsk_size_t) tree_pos.in.start; + k = (tsk_size_t) tree_pos.out.start; + while (m < self->num_windows) { + if (valid && self->position == tree_pos.interval.left) { + for (k = (tsk_size_t) tree_pos.out.start; + k != (tsk_size_t) tree_pos.out.stop; k++) { + e = out_order[k]; + p = edge_parent[e]; + c = edge_child[e]; + tsk_matvec_calculator_remove_edge(self, p, c); + } + for (j = (tsk_size_t) tree_pos.in.start; j != (tsk_size_t) tree_pos.in.stop; + j++) { + e = in_order[j]; + p = edge_parent[e]; + c = edge_child[e]; + tsk_matvec_calculator_insert_edge(self, p, c); + } + valid = tsk_tree_position_next(&tree_pos); + } + next_position = windows[m + 1]; if (j < num_edges) { - tree_right = TSK_MIN(tree_right, edge_left[I[j]]); + next_position = TSK_MIN(next_position, edge_left[in_order[j]]); } if (k < num_edges) { - tree_right = TSK_MIN(tree_right, edge_right[O[k]]); + next_position = TSK_MIN(next_position, edge_right[out_order[k]]); } - self->tree_left = tree_right; - if (self->tree_left == self->windows[m + 1]) { - out = GET_2D_ROW(self->result, self->num_weights * n, m); + tsk_bug_assert(self->position < next_position); + self->position = next_position; + if (self->position == windows[m + 1]) { + out = GET_2D_ROW(self->result, out_size, m); tsk_matvec_calculator_write_output(self, out); m += 1; } @@ -10231,7 +10281,8 @@ tsk_matvec_calculator_run(tsk_matvec_calculator_t *self) int tsk_treeseq_genetic_relatedness_vector(const tsk_treeseq_t *self, tsk_size_t num_weights, - const double *weights, tsk_size_t num_windows, const double *windows, double *result, + const double *weights, tsk_size_t num_windows, const double *windows, + tsk_size_t num_focal_nodes, const tsk_id_t *focal_nodes, double *result, tsk_flags_t options) { int ret = 0; @@ -10245,13 +10296,13 @@ tsk_treeseq_genetic_relatedness_vector(const tsk_treeseq_t *self, tsk_size_t num ret = TSK_ERR_UNSUPPORTED_STAT_MODE; goto out; } - ret = tsk_treeseq_check_windows(self, num_windows, windows, TSK_REQUIRE_FULL_SPAN); + ret = tsk_treeseq_check_windows(self, num_windows, windows, 0); if (ret != 0) { goto out; } - ret = tsk_matvec_calculator_init( - &calc, self, num_weights, weights, num_windows, windows, options, result); + ret = tsk_matvec_calculator_init(&calc, self, num_weights, weights, num_windows, + windows, num_focal_nodes, focal_nodes, options, result); if (ret != 0) { goto out; } diff --git a/c/tskit/trees.h b/c/tskit/trees.h index df9cf92850..bef944fff3 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1028,12 +1028,14 @@ int tsk_treeseq_genetic_relatedness_weighted(const tsk_treeseq_t *self, /* One way weighted stats with vector output */ typedef int weighted_vector_method(const tsk_treeseq_t *self, tsk_size_t num_weights, - const double *weights, tsk_size_t num_windows, const double *windows, double *result, + const double *weights, tsk_size_t num_windows, const double *windows, + tsk_size_t num_focal_nodes, const tsk_id_t *focal_nodes, double *result, tsk_flags_t options); int tsk_treeseq_genetic_relatedness_vector(const tsk_treeseq_t *self, tsk_size_t num_weights, const double *weights, tsk_size_t num_windows, - const double *windows, double *result, tsk_flags_t options); + const double *windows, tsk_size_t num_focal_nodes, const tsk_id_t *focal_nodes, + double *result, tsk_flags_t options); /* One way sample set stats */ diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 6d275a499a..8663bd8695 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -9652,14 +9652,17 @@ TreeSequence_weighted_stat_vector_method( { PyObject *ret = NULL; static char *kwlist[] - = { "weights", "windows", "mode", "span_normalise", "centre", NULL }; + = { "weights", "windows", "mode", "span_normalise", "centre", "nodes", NULL }; PyObject *weights = NULL; PyObject *windows = NULL; + PyObject *focal_nodes = NULL; PyArrayObject *weights_array = NULL; PyArrayObject *windows_array = NULL; PyArrayObject *result_array = NULL; + PyArrayObject *focal_nodes_array = NULL; tsk_size_t num_windows; - npy_intp *w_shape, result_shape[3]; + tsk_size_t num_focal_nodes; + npy_intp *focal_nodes_shape, *w_shape, result_shape[3]; tsk_flags_t options = 0; tsk_size_t num_samples; char *mode = NULL; @@ -9670,8 +9673,8 @@ TreeSequence_weighted_stat_vector_method( if (TreeSequence_check_state(self) != 0) { goto out; } - if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO|sii", kwlist, &weights, &windows, - &mode, &span_normalise, ¢re)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO|siiO", kwlist, &weights, &windows, + &mode, &span_normalise, ¢re, &focal_nodes)) { goto out; } if (parse_stats_mode(mode, &options) != 0) { @@ -9697,16 +9700,24 @@ TreeSequence_weighted_stat_vector_method( PyErr_SetString(PyExc_ValueError, "First dimension must be num_samples"); goto out; } + focal_nodes_array = (PyArrayObject *) PyArray_FROMANY( + focal_nodes, NPY_INT32, 1, 1, NPY_ARRAY_IN_ARRAY); + if (focal_nodes_array == NULL) { + goto out; + } + focal_nodes_shape = PyArray_DIMS(focal_nodes_array); + num_focal_nodes = focal_nodes_shape[0]; result_shape[0] = num_windows; - result_shape[1] = num_samples; + result_shape[1] = num_focal_nodes; result_shape[2] = w_shape[1]; result_array = (PyArrayObject *) PyArray_SimpleNew(3, result_shape, NPY_FLOAT64); if (result_array == NULL) { goto out; } err = method(self->tree_sequence, w_shape[1], PyArray_DATA(weights_array), - num_windows, PyArray_DATA(windows_array), PyArray_DATA(result_array), options); + num_windows, PyArray_DATA(windows_array), num_focal_nodes, + PyArray_DATA(focal_nodes_array), PyArray_DATA(result_array), options); if (err != 0) { handle_library_error(err); goto out; @@ -9716,6 +9727,7 @@ TreeSequence_weighted_stat_vector_method( out: Py_XDECREF(weights_array); Py_XDECREF(windows_array); + Py_XDECREF(focal_nodes_array); Py_XDECREF(result_array); return ret; } diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 1385adc283..c07f975f83 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -2682,6 +2682,7 @@ def get_example(self, num_weights=2): (num_samples, num_weights) ), "windows": [0, ts.get_sequence_length()], + "nodes": list(ts.get_samples()), } return ts, params @@ -2690,18 +2691,38 @@ def get_example(self, num_weights=2): def test_basic_example(self, mode, num_weights): ts, params = self.get_example(num_weights) ns = ts.get_num_samples() - result = ts.genetic_relatedness_vector( - params["weights"], params["windows"], mode, True, True - ) - assert result.shape == (1, ns, num_weights) - result = ts.genetic_relatedness_vector( - params["weights"], params["windows"], mode, True, False - ) - assert result.shape == (1, ns, num_weights) - result = ts.genetic_relatedness_vector( - params["weights"], params["windows"], mode, False, True - ) - assert result.shape == (1, ns, num_weights) + params["mode"] = mode + for a, b in ([True, True], [True, False], [False, True]): + params["span_normalise"] = a + params["centre"] = b + result = ts.genetic_relatedness_vector(**params) + assert result.shape == (1, ns, num_weights) + + @pytest.mark.parametrize("mode", ["branch"]) + def test_good_nodes(self, mode): + num_weights = 2 + ts, params = self.get_example(num_weights) + params["mode"] = mode + for nodes in [ + list(ts.get_samples())[:3], + list(ts.get_samples())[:1], + [0, ts.get_num_nodes() - 1], + ]: + params["nodes"] = nodes + result = ts.genetic_relatedness_vector(**params) + assert result.shape == (1, len(nodes), num_weights) + + def test_bad_nodes(self): + ts, params = self.get_example() + params["mode"] = "branch" + for nodes in ["abc", [[1, 2]]]: + params["nodes"] = nodes + with pytest.raises(ValueError, match="desired array"): + ts.genetic_relatedness_vector(**params) + for nodes in [[-1, 3], [3, 2 * ts.get_num_nodes()]]: + params["nodes"] = nodes + with pytest.raises(_tskit.LibraryError, match="TSK_ERR_NODE_OUT_OF_BOUNDS"): + ts.genetic_relatedness_vector(**params) def test_bad_args(self): ts, params = self.get_example() @@ -2727,10 +2748,9 @@ def test_bad_args(self): @pytest.mark.parametrize("mode", ["site", "node"]) def test_modes_not_supported(self, mode): ts, params = self.get_example() + params["mode"] = mode with pytest.raises(_tskit.LibraryError): - ts.genetic_relatedness_vector( - params["weights"], params["windows"], mode, True, True - ) + ts.genetic_relatedness_vector(**params) @pytest.mark.parametrize("mode", ["branch"]) def test_bad_weights(self, mode): @@ -2765,12 +2785,10 @@ def test_window_errors(self): L = ts.get_sequence_length() bad_windows = [ [L, 0], - [0.1, L], [-1, L], [0, L + 0.1], [0, 0.1, 0.1, L], [0, -1, L], - [0, 0.1, 0.05, 0.2, L], ] for bad_window in bad_windows: with pytest.raises(_tskit.LibraryError): diff --git a/python/tests/test_relatedness_vector.py b/python/tests/test_relatedness_vector.py index e6ee4555d3..f765c75c9f 100644 --- a/python/tests/test_relatedness_vector.py +++ b/python/tests/test_relatedness_vector.py @@ -27,6 +27,7 @@ import pytest import tskit +from tests import tsutil from tests.test_highlevel import get_example_tree_sequences # ↑ See https://github.com/tskit-dev/tskit/issues/1804 for when @@ -45,14 +46,14 @@ def __init__( windows, num_nodes, samples, + focal_nodes, nodes_time, edges_left, edges_right, edges_parent, edges_child, - edge_insertion_order, - edge_removal_order, sequence_length, + tree_pos, verbosity=0, internal_checks=False, centre=True, @@ -67,12 +68,12 @@ def __init__( self.edges_right = edges_right self.edges_parent = edges_parent self.edges_child = edges_child - self.edge_insertion_order = edge_insertion_order - self.edge_removal_order = edge_removal_order self.sequence_length = sequence_length self.nodes_time = nodes_time self.samples = samples - self.position = 0.0 + self.focal_nodes = focal_nodes + self.tree_pos = tree_pos + self.position = windows[0] self.x = np.zeros(N, dtype=np.float64) self.w = np.zeros((N, self.num_weights), dtype=np.float64) self.v = np.zeros((N, self.num_weights), dtype=np.float64) @@ -92,6 +93,8 @@ def __init__( def print_state(self, msg=""): num_nodes = len(self.parent) print(f"..........{msg}................") + print("tree_pos:") + print(self.tree_pos) print(f"position = {self.position}") for j in range(num_nodes): st = f"{self.nodes_time[j]}" @@ -191,9 +194,9 @@ def write_output(self): Compute and return the current state, zero-ing out all contributions (used for switching between windows). """ - n = len(self.samples) + n = len(self.focal_nodes) out = np.zeros((n, self.num_weights)) - for j, c in enumerate(self.samples): + for j, c in enumerate(self.focal_nodes): while c != tskit.NULL: if self.x[c] != self.position: self.v[c] += self.get_z(c) @@ -209,9 +212,9 @@ def current_state(self): """ if self.verbosity > 2: print("---------------") - n = len(self.samples) + n = len(self.focal_nodes) out = np.zeros((n, self.num_weights)) - for j, a in enumerate(self.samples): + for j, a in enumerate(self.focal_nodes): # edges on the path up from a pa = a while pa != tskit.NULL: @@ -225,38 +228,53 @@ def current_state(self): def run(self): M = self.edges_left.shape[0] - in_order = self.edge_insertion_order - out_order = self.edge_removal_order edges_left = self.edges_left edges_right = self.edges_right edges_parent = self.edges_parent edges_child = self.edges_child + tree_pos = self.tree_pos + in_order = tree_pos.in_range.order + out_order = tree_pos.out_range.order num_windows = len(self.windows) - 1 - out = np.zeros((num_windows,) + self.sample_weights.shape) + out = np.zeros( + (num_windows, len(self.focal_nodes), self.sample_weights.shape[1]) + ) - j = 0 - k = 0 m = 0 - self.position = 0 - - while m < num_windows and k < M and self.position <= self.sequence_length: - while k < M and edges_right[out_order[k]] == self.position: - p = edges_parent[out_order[k]] - c = edges_child[out_order[k]] - self.remove_edge(p, c) - k += 1 - while j < M and edges_left[in_order[j]] == self.position: - p = edges_parent[in_order[j]] - c = edges_child[in_order[j]] + self.position = self.windows[0] + + # seek to first window + for j in range(tree_pos.in_range.start, tree_pos.in_range.stop, 1): + e = in_order[j] + if edges_left[e] <= self.position and self.position < edges_right[e]: + p = edges_parent[e] + c = edges_child[e] self.insert_edge(p, c) - assert self.parent[p] == tskit.NULL or self.x[p] == self.position - j += 1 - right = self.windows[m + 1] - if j < M: - right = min(right, edges_left[in_order[j]]) - if k < M: - right = min(right, edges_right[out_order[k]]) - self.position = right + + valid = tree_pos.next() + j = tree_pos.in_range.start - 1 + k = tree_pos.out_range.start - 1 + while m < num_windows: + if valid and self.position == tree_pos.interval.left: + for k in range(tree_pos.out_range.start, tree_pos.out_range.stop, 1): + e = out_order[k] + p = edges_parent[e] + c = edges_child[e] + self.remove_edge(p, c) + for j in range(tree_pos.in_range.start, tree_pos.in_range.stop, 1): + e = in_order[j] + p = edges_parent[e] + c = edges_child[e] + self.insert_edge(p, c) + assert self.parent[p] == tskit.NULL or self.x[p] == self.position + valid = tree_pos.next() + next_position = self.windows[m + 1] + if j + 1 < M: + next_position = min(next_position, edges_left[in_order[j + 1]]) + if k + 1 < M: + next_position = min(next_position, edges_right[out_order[k + 1]]) + assert self.position < next_position + self.position = next_position if self.position == self.windows[m + 1]: out[m] = self.write_output() m = m + 1 @@ -270,25 +288,35 @@ def run(self): return out -def relatedness_vector(ts, sample_weights, windows=None, **kwargs): +def relatedness_vector(ts, sample_weights, windows=None, nodes=None, **kwargs): if len(sample_weights.shape) == 1: sample_weights = sample_weights[:, np.newaxis] + if nodes is None: + nodes = np.fromiter(ts.samples(), dtype=np.int32) drop_dimension = windows is None if drop_dimension: windows = [0, ts.sequence_length] + + tree_pos = tsutil.TreePosition(ts) + breakpoints = np.fromiter(ts.breakpoints(), dtype="float") + index = np.searchsorted(breakpoints, windows[0]) + if breakpoints[index] > windows[0]: + index -= 1 + tree_pos.seek_forward(index) + rv = RelatednessVector( sample_weights, windows, ts.num_nodes, samples=ts.samples(), + focal_nodes=nodes, nodes_time=ts.nodes_time, edges_left=ts.edges_left, edges_right=ts.edges_right, edges_parent=ts.edges_parent, edges_child=ts.edges_child, - edge_insertion_order=ts.indexes_edge_insertion_order, - edge_removal_order=ts.indexes_edge_removal_order, sequence_length=ts.sequence_length, + tree_pos=tree_pos, **kwargs, ) out = rv.run() @@ -298,25 +326,56 @@ def relatedness_vector(ts, sample_weights, windows=None, **kwargs): return out -def relatedness_matrix(ts, windows, centre): +def relatedness_matrix(ts, windows, centre, nodes=None): + if nodes is None: + keep_rows = np.arange(ts.num_samples) + keep_cols = np.arange(ts.num_samples) + else: + orig_samples = list(ts.samples()) + extra_nodes = set(nodes).difference(set(orig_samples)) + tables = ts.dump_tables() + tables.nodes.clear() + for n in ts.nodes(): + if n.id in extra_nodes: + n = n.replace(flags=n.flags | tskit.NODE_IS_SAMPLE) + tables.nodes.append(n) + ts = tables.tree_sequence() + all_samples = list(ts.samples()) + keep_rows = np.array([all_samples.index(i) for i in nodes]) + keep_cols = np.array([all_samples.index(i) for i in orig_samples]) + + use_windows = windows + drop_first = windows is not None and windows[0] > 0 + if drop_first: + use_windows = np.concatenate([[0], np.array(use_windows)]) + drop_last = windows is not None and windows[-1] < ts.sequence_length + if drop_last: + use_windows = np.concatenate([np.array(use_windows), [ts.sequence_length]]) Sigma = ts.genetic_relatedness( sample_sets=[[i] for i in ts.samples()], indexes=[(i, j) for i in range(ts.num_samples) for j in range(ts.num_samples)], - windows=windows, + windows=use_windows, mode="branch", span_normalise=False, proportion=False, centre=centre, ) if windows is not None: - shape = (len(windows) - 1, ts.num_samples, ts.num_samples) - else: - shape = (ts.num_samples, ts.num_samples) - return Sigma.reshape(shape) + if drop_first: + Sigma = Sigma[1:] + if drop_last: + Sigma = Sigma[:-1] + nwin = 1 if windows is None else len(windows) - 1 + shape = (nwin, ts.num_samples, ts.num_samples) + Sigma = Sigma.reshape(shape) + out = np.array([x[np.ix_(keep_rows, keep_cols)] for x in Sigma]) + if windows is None: + out = out[0] + return out def verify_relatedness_vector( - ts, w, windows, *, internal_checks=False, verbosity=0, centre=True + ts, w, windows, *, internal_checks=False, verbosity=0, centre=True, nodes=None ): R1 = relatedness_vector( ts, @@ -325,55 +384,79 @@ def verify_relatedness_vector( internal_checks=internal_checks, verbosity=verbosity, centre=centre, + nodes=nodes, ) + nrows = ts.num_samples if nodes is None else len(nodes) wvec = w if len(w.shape) > 1 else w[:, np.newaxis] - Sigma = relatedness_matrix(ts, windows=windows, centre=centre) + Sigma = relatedness_matrix(ts, windows=windows, centre=centre, nodes=nodes) if windows is None: R2 = Sigma.dot(wvec) else: - R2 = np.zeros((len(windows) - 1, ts.num_samples, wvec.shape[1])) + R2 = np.zeros((len(windows) - 1, nrows, wvec.shape[1])) for k in range(len(windows) - 1): R2[k] = Sigma[k].dot(wvec) - R3 = ts.genetic_relatedness_vector(w, windows=windows, mode="branch", centre=centre) + R3 = ts.genetic_relatedness_vector( + w, windows=windows, mode="branch", centre=centre, nodes=nodes + ) if verbosity > 0: print(ts.draw_text()) print("weights:", w) print("windows:", windows) + print("centre:", centre) print("here:", R1) print("with ts:", R2) print("with lib:", R3) print("Sigma:", Sigma) if windows is None: - assert R1.shape == (ts.num_samples, wvec.shape[1]) + assert R1.shape == (nrows, wvec.shape[1]) else: - assert R1.shape == (len(windows) - 1, ts.num_samples, wvec.shape[1]) - np.testing.assert_allclose(R1, R2, atol=1e-13) - np.testing.assert_allclose(R1, R3, atol=1e-13) + assert R1.shape == (len(windows) - 1, nrows, wvec.shape[1]) + np.testing.assert_allclose(R1, R2, atol=1e-10) + np.testing.assert_allclose(R1, R3, atol=1e-10) return R1 def check_relatedness_vector( - ts, n=2, num_windows=0, *, internal_checks=False, verbosity=0, seed=123, centre=True + ts, + n=2, + num_windows=0, + *, + internal_checks=False, + verbosity=0, + seed=123, + centre=True, + do_nodes=True, ): rng = np.random.default_rng(seed=seed) if num_windows == 0: windows = None + elif num_windows % 2 == 0: + windows = np.linspace( + 0.2 * ts.sequence_length, 0.8 * ts.sequence_length, num_windows + 1 + ) else: windows = np.linspace(0, ts.sequence_length, num_windows + 1) - for k in range(n): - if k == 0: - w = rng.normal(size=ts.num_samples) + num_nodes_list = (0,) if (centre or not do_nodes) else (0, 3) + for num_nodes in num_nodes_list: + if num_nodes == 0: + nodes = None else: - w = rng.normal(size=ts.num_samples * k).reshape((ts.num_samples, k)) - w = np.round(len(w) * w) - R = verify_relatedness_vector( - ts, - w, - windows, - internal_checks=internal_checks, - verbosity=verbosity, - centre=centre, - ) + nodes = rng.choice(ts.num_nodes, num_nodes, replace=False) + for k in range(n): + if k == 0: + w = rng.normal(size=ts.num_samples) + else: + w = rng.normal(size=ts.num_samples * k).reshape((ts.num_samples, k)) + w = np.round(len(w) * w) + R = verify_relatedness_vector( + ts, + w, + windows, + internal_checks=internal_checks, + verbosity=verbosity, + centre=centre, + nodes=nodes, + ) return R @@ -405,10 +488,87 @@ def test_bad_windows(self): np.ones(ts.num_samples), windows=bad_w, mode="branch" ) + def test_nodes_centred_error(self): + ts = msprime.sim_ancestry( + 5, + ploidy=2, + sequence_length=10, + random_seed=123, + ) + with pytest.raises(ValueError, match="must have centre"): + ts.genetic_relatedness_vector( + np.ones(ts.num_samples), mode="branch", centre=True, nodes=[0, 1] + ) + + def test_bad_nodes(self): + n = 5 + ts = msprime.sim_ancestry( + n, + ploidy=2, + sequence_length=10, + random_seed=123, + ) + for bad_nodes in ([[]], "foo"): + with pytest.raises(ValueError): + ts.genetic_relatedness_vector( + np.ones(ts.num_samples), + mode="branch", + centre=False, + nodes=bad_nodes, + ) + for bad_nodes in ([-1, 10], [3, 2 * ts.num_nodes]): + with pytest.raises(tskit.LibraryError, match="TSK_ERR_NODE_OUT_OF_BOUNDS"): + ts.genetic_relatedness_vector( + np.ones(ts.num_samples), + mode="branch", + centre=False, + nodes=bad_nodes, + ) + + def test_good_nodes(self): + n = 5 + ts = msprime.sim_ancestry( + n, + ploidy=2, + sequence_length=10, + random_seed=123, + ) + V0 = ts.genetic_relatedness_vector( + np.ones(ts.num_samples), mode="branch", centre=False + ) + V = ts.genetic_relatedness_vector( + np.ones(ts.num_samples), + mode="branch", + centre=False, + nodes=list(ts.samples()), + ) + np.testing.assert_allclose(V0, V, atol=1e-13) + V = ts.genetic_relatedness_vector( + np.ones(ts.num_samples), + mode="branch", + centre=False, + nodes=np.fromiter(ts.samples(), dtype=np.int32), + ) + np.testing.assert_allclose(V0, V, atol=1e-13) + V = ts.genetic_relatedness_vector( + np.ones(ts.num_samples), + mode="branch", + centre=False, + nodes=np.fromiter(ts.samples(), dtype=np.int64), + ) + np.testing.assert_allclose(V0, V, atol=1e-13) + V = ts.genetic_relatedness_vector( + np.ones(ts.num_samples), + mode="branch", + centre=False, + nodes=list(ts.samples())[:2], + ) + np.testing.assert_allclose(V0[:2], V, atol=1e-13) + @pytest.mark.parametrize("n", [2, 3, 5]) @pytest.mark.parametrize("seed", range(1, 4)) @pytest.mark.parametrize("centre", (True, False)) - @pytest.mark.parametrize("num_windows", (0, 1, 2)) + @pytest.mark.parametrize("num_windows", (0, 1, 2, 3)) def test_small_internal_checks(self, n, seed, centre, num_windows): ts = msprime.sim_ancestry( n, @@ -418,12 +578,14 @@ def test_small_internal_checks(self, n, seed, centre, num_windows): random_seed=seed, ) assert ts.num_trees >= 2 - check_relatedness_vector(ts, internal_checks=True, centre=centre) + check_relatedness_vector( + ts, num_windows=num_windows, internal_checks=True, centre=centre + ) @pytest.mark.parametrize("n", [2, 3, 5, 15]) @pytest.mark.parametrize("seed", range(1, 5)) @pytest.mark.parametrize("centre", (True, False)) - @pytest.mark.parametrize("num_windows", (0, 1, 3)) + @pytest.mark.parametrize("num_windows", (0, 1, 2, 3)) def test_simple_sims(self, n, seed, centre, num_windows): ts = msprime.sim_ancestry( n, @@ -438,6 +600,31 @@ def test_simple_sims(self, n, seed, centre, num_windows): ts, num_windows=num_windows, centre=centre, verbosity=0 ) + def test_simple_sims_windows(self): + L = 100 + ts = msprime.sim_ancestry( + 5, + ploidy=1, + population_size=20, + sequence_length=L, + recombination_rate=0.01, + random_seed=345, + ) + assert ts.num_trees >= 2 + W = np.linspace(0, 1, 2 * ts.num_samples).reshape((ts.num_samples, 2)) + kwargs = {"centre": False, "mode": "branch"} + total = ts.genetic_relatedness_vector(W, **kwargs) + for windows in [[0, L], [0, L / 3, L / 2, L]]: + pieces = ts.genetic_relatedness_vector(W, windows=windows, **kwargs) + np.testing.assert_allclose(total, pieces.sum(axis=0), atol=1e-13) + assert len(pieces) == len(windows) - 1 + for k in range(len(pieces)): + piece = ts.genetic_relatedness_vector( + W, windows=windows[k : k + 2], **kwargs + ) + assert piece.shape[0] == 1 + np.testing.assert_allclose(piece[0], pieces[k], atol=1e-13) + @pytest.mark.parametrize("n", [2, 3, 5, 15]) @pytest.mark.parametrize("centre", (True, False)) def test_single_balanced_tree(self, n, centre): @@ -456,7 +643,7 @@ def test_internal_sample(self, centre): @pytest.mark.parametrize("seed", range(1, 5)) @pytest.mark.parametrize("centre", (True, False)) - @pytest.mark.parametrize("num_windows", (0, 1, 2)) + @pytest.mark.parametrize("num_windows", (0, 1, 2, 3)) def test_one_internal_sample_sims(self, seed, centre, num_windows): ts = msprime.sim_ancestry( 10, @@ -476,7 +663,7 @@ def test_one_internal_sample_sims(self, seed, centre, num_windows): check_relatedness_vector(ts, num_windows=num_windows, centre=centre) @pytest.mark.parametrize("centre", (True, False)) - @pytest.mark.parametrize("num_windows", (0, 1, 2)) + @pytest.mark.parametrize("num_windows", (0, 1, 2, 3)) def test_missing_flanks(self, centre, num_windows): ts = msprime.sim_ancestry( 2, @@ -502,7 +689,7 @@ def test_dangling_on_samples(self, n): # Adding non sample branches below the samples does not alter # the overall divergence *between* the samples ts1 = tskit.Tree.generate_balanced(n).tree_sequence - D1 = check_relatedness_vector(ts1) + D1 = check_relatedness_vector(ts1, do_nodes=False) tables = ts1.dump_tables() for u in ts1.samples(): v = tables.nodes.add_row(time=-1) @@ -510,7 +697,7 @@ def test_dangling_on_samples(self, n): tables.sort() tables.build_index() ts2 = tables.tree_sequence() - D2 = check_relatedness_vector(ts2, internal_checks=True) + D2 = check_relatedness_vector(ts2, internal_checks=True, do_nodes=False) np.testing.assert_array_almost_equal(D1, D2) @pytest.mark.parametrize("n", [2, 3, 10]) @@ -519,7 +706,7 @@ def test_dangling_on_all(self, n, centre): # Adding non sample branches below the samples does not alter # the overall divergence *between* the samples ts1 = tskit.Tree.generate_balanced(n).tree_sequence - D1 = check_relatedness_vector(ts1, centre=centre) + D1 = check_relatedness_vector(ts1, centre=centre, do_nodes=False) tables = ts1.dump_tables() for u in range(ts1.num_nodes): v = tables.nodes.add_row(time=-1) @@ -527,7 +714,9 @@ def test_dangling_on_all(self, n, centre): tables.sort() tables.build_index() ts2 = tables.tree_sequence() - D2 = check_relatedness_vector(ts2, internal_checks=True, centre=centre) + D2 = check_relatedness_vector( + ts2, internal_checks=True, centre=centre, do_nodes=False + ) np.testing.assert_array_almost_equal(D1, D2) @pytest.mark.parametrize("centre", (True, False)) @@ -535,7 +724,7 @@ def test_disconnected_non_sample_topology(self, centre): # Adding non sample branches below the samples does not alter # the overall divergence *between* the samples ts1 = tskit.Tree.generate_balanced(5).tree_sequence - D1 = check_relatedness_vector(ts1, centre=centre) + D1 = check_relatedness_vector(ts1, centre=centre, do_nodes=False) tables = ts1.dump_tables() # Add an extra bit of disconnected non-sample topology u = tables.nodes.add_row(time=0) @@ -544,5 +733,7 @@ def test_disconnected_non_sample_topology(self, centre): tables.sort() tables.build_index() ts2 = tables.tree_sequence() - D2 = check_relatedness_vector(ts2, internal_checks=True, centre=centre) + D2 = check_relatedness_vector( + ts2, internal_checks=True, centre=centre, do_nodes=False + ) np.testing.assert_array_almost_equal(D1, D2) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index e617c7ea8d..3d4637829d 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -7858,10 +7858,20 @@ def __weighted_vector_stat( mode=None, span_normalise=True, centre=True, + nodes=None, ): W = np.asarray(W) if len(W.shape) == 1: W = W.reshape(W.shape[0], 1) + if nodes is None: + nodes = list(self.samples()) + else: + if centre: + raise ValueError("If `nodes` is provided, must have centre=False.") + try: + nodes = util.safe_np_int_cast(nodes, np.int32) + except Exception: + raise ValueError("Could not interpret `nodes` as a list of node IDs.") stat = self.__run_windowed_stat( windows, ll_method, @@ -7869,6 +7879,7 @@ def __weighted_vector_stat( mode=mode, span_normalise=span_normalise, centre=centre, + nodes=nodes, ) return stat @@ -8518,20 +8529,31 @@ def genetic_relatedness_vector( mode="site", span_normalise=True, centre=True, + nodes=None, ): r""" Computes the product of the genetic relatedness matrix and a vector of weights (one per sample). The output is a (num windows) x (num samples) x (num weights) - array whose :math:`(i,j)`-th element is :math:`\sum_{b} W_{bj} C_{ib}`, + array whose :math:`(w,i,j)`-th element is :math:`\sum_{b} W_{bj} C_{ib}`, where :math:`W` is the matrix of weights, and :math:`C_{ab}` is the :meth:`genetic_relatedness <.TreeSequence.genetic_relatedness>` between sample - a and sample b, and the sum is over all samples in the tree sequence. - Like other statistics, if windows is None, the first dimension in the output is - dropped. + `a` and sample `b` in window `w`, and the sum is over all samples in the tree + sequence. Like other statistics, if windows is None, the first dimension in + the output is dropped. The relatedness used here corresponds to `polarised=True`; no unpolarised option is available for this method. + Optionally, you may provide a list of focal nodes that modifies the behavior + as follows. If `nodes` is a list of `n` node IDs (that do not need to be + samples), then the output will have dimension (num windows) x n x (num weights), + and the matrix :math:`C` used in the definition above is the rectangular matrix + with :math:`C_{ij}` the relatedness between `nodes[i]` and `samples[j]`. This + can only be used with `centre=False`; if relatedness between uncentred nodes + and centred samples is desired, then simply subtract column means from `W` first. + The default is `nodes=None`, which is equivalent to setting `nodes` equal to + `ts.samples()`. + :param numpy.ndarray W: An array of values with one row for each sample node and one column for each set of weights. :param list windows: An increasing list of breakpoints between the windows @@ -8542,6 +8564,8 @@ def genetic_relatedness_vector( window (defaults to True). :param bool centre: Whether to use the *centred* relatedness matrix or not: see :meth:`genetic_relatedness <.TreeSequence.genetic_relatedness>`. + :param list nodes: Optionally, a list of focal nodes as described above + (default: None). :return: A ndarray with shape equal to (num windows, num samples, num weights), or (num samples, num weights) if windows is None. """ @@ -8549,6 +8573,7 @@ def genetic_relatedness_vector( raise ValueError( "First weight dimension must be equal to number of samples." ) + out = self.__weighted_vector_stat( self._ll_tree_sequence.genetic_relatedness_vector, W, @@ -8556,6 +8581,7 @@ def genetic_relatedness_vector( mode=mode, span_normalise=span_normalise, centre=centre, + nodes=nodes, ) return out