diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c index 0e5c2e17a1..7dccc2346d 100644 --- a/c/tests/test_stats.c +++ b/c/tests/test_stats.c @@ -1992,6 +1992,176 @@ test_paper_ex_genetic_relatedness_weighted_errors(void) tsk_treeseq_free(&ts); } +static void +test_empty_genetic_relatedness_vector(void) +{ + int ret; + tsk_treeseq_t ts; + tsk_size_t num_samples; + double *weights, *result; + tsk_size_t j; + tsk_size_t num_weights = 2; + double windows[] = { 0, 0 }; + + tsk_treeseq_from_text( + &ts, 1, single_tree_ex_nodes, "", NULL, NULL, NULL, NULL, NULL, 0); + num_samples = tsk_treeseq_get_num_samples(&ts); + windows[1] = tsk_treeseq_get_sequence_length(&ts); + weights = tsk_malloc(num_weights * num_samples * sizeof(double)); + result = tsk_malloc(num_weights * num_samples * sizeof(double)); + for (j = 0; j < num_samples; j++) { + weights[j] = 1.0; + } + for (j = 0; j < num_samples; j++) { + weights[j + num_samples] = (float) j; + } + + ret = tsk_treeseq_genetic_relatedness_vector( + &ts, num_weights, weights, 1, windows, result, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_treeseq_genetic_relatedness_vector( + &ts, num_weights, weights, 1, windows, result, TSK_STAT_NONCENTRED); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tsk_treeseq_free(&ts); + free(weights); + free(result); +} + +static void +verify_genetic_relatedness_vector( + tsk_treeseq_t *ts, tsk_size_t num_weights, tsk_size_t num_windows) +{ + int ret; + tsk_size_t num_samples; + double *weights, *result; + tsk_size_t j, k; + double *windows = tsk_malloc((num_windows + 1) * sizeof(*windows)); + double L = tsk_treeseq_get_sequence_length(ts); + + windows[0] = 0; + windows[num_windows] = L; + for (j = 1; j < num_windows; j++) { + windows[j] = ((double) j) * L / (double) num_windows; + } + num_samples = tsk_treeseq_get_num_samples(ts); + + weights = tsk_malloc(num_weights * num_samples * sizeof(*weights)); + result = tsk_malloc(num_windows * num_weights * num_samples * sizeof(*result)); + for (j = 0; j 
< num_samples; j++) { + for (k = 0; k < num_weights; k++) { + weights[j + k * num_samples] = 1.0 + (double) k; + } + } + + ret = tsk_treeseq_genetic_relatedness_vector( + ts, num_weights, weights, num_windows, windows, result, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_treeseq_genetic_relatedness_vector( + ts, num_weights, weights, num_windows, windows, result, TSK_STAT_NONCENTRED); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tsk_set_debug_stream(_devnull); + ret = tsk_treeseq_genetic_relatedness_vector( + ts, num_weights, weights, num_windows, windows, result, TSK_DEBUG); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_set_debug_stream(stdout); + + free(windows); + free(weights); + free(result); +} + +static void +test_paper_ex_genetic_relatedness_vector(void) +{ + tsk_treeseq_t ts; + double gap; + + for (gap = 0.0; gap < 2.0; gap += 1.0) { + tsk_treeseq_from_text(&ts, 10 + gap, paper_ex_nodes, paper_ex_edges, NULL, + paper_ex_sites, paper_ex_mutations, paper_ex_individuals, NULL, 0); + + tsk_size_t j, k; + for (j = 1; j < 3; j++) { + for (k = 1; k < 3; k++) { + verify_genetic_relatedness_vector(&ts, j, k); + } + } + tsk_treeseq_free(&ts); + } +} + +static void +test_paper_ex_genetic_relatedness_vector_errors(void) +{ + int ret; + tsk_treeseq_t ts; + tsk_size_t num_samples; + double *weights, *result; + tsk_size_t j; + tsk_size_t num_weights = 2; + double windows[] = { 0, 0, 0 }; + + tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, + paper_ex_mutations, paper_ex_individuals, NULL, 0); + num_samples = tsk_treeseq_get_num_samples(&ts); + + weights = tsk_malloc(num_weights * num_samples * sizeof(double)); + result = tsk_malloc(num_weights * num_samples * sizeof(double)); + for (j = 0; j < num_samples; j++) { + weights[j] = 1.0; + } + for (j = 0; j < num_samples; j++) { + weights[j + num_samples] = (float) j; + } + + /* Window errors */ + ret = tsk_treeseq_genetic_relatedness_vector( + &ts, 1, weights, 0, windows, result, TSK_STAT_BRANCH); 
+ CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_NUM_WINDOWS); + ret = tsk_treeseq_genetic_relatedness_vector( + &ts, 1, weights, 0, NULL, result, TSK_STAT_BRANCH); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_NUM_WINDOWS); + + ret = tsk_treeseq_genetic_relatedness_vector( + &ts, 1, weights, 2, windows, result, TSK_STAT_BRANCH); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); + + windows[0] = -1; + ret = tsk_treeseq_genetic_relatedness_vector( + &ts, 1, weights, 2, windows, result, TSK_STAT_BRANCH); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); + + windows[0] = 10; + ret = tsk_treeseq_genetic_relatedness_vector( + &ts, 1, weights, 2, windows, result, TSK_STAT_BRANCH); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); + + windows[0] = 0; + windows[2] = 12; + ret = tsk_treeseq_genetic_relatedness_vector( + &ts, 1, weights, 2, windows, result, TSK_STAT_BRANCH); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); + + /* unsupported mode errors */ + windows[0] = 0.0; + windows[1] = 5.0; + windows[2] = 10.0; + ret = tsk_treeseq_genetic_relatedness_vector( + &ts, num_weights, weights, 2, windows, result, TSK_STAT_SITE); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSUPPORTED_STAT_MODE); + ret = tsk_treeseq_genetic_relatedness_vector( + &ts, num_weights, weights, 2, windows, result, TSK_STAT_NODE); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSUPPORTED_STAT_MODE); + + tsk_treeseq_free(&ts); + free(weights); + free(result); +} + static void test_paper_ex_Y2_errors(void) { @@ -3532,6 +3702,12 @@ main(int argc, char **argv) test_paper_ex_genetic_relatedness_weighted }, { "test_paper_ex_genetic_relatedness_weighted_errors", test_paper_ex_genetic_relatedness_weighted_errors }, + { "test_empty_genetic_relatedness_vector", + test_empty_genetic_relatedness_vector }, + { "test_paper_ex_genetic_relatedness_vector", + test_paper_ex_genetic_relatedness_vector }, + { "test_paper_ex_genetic_relatedness_vector_errors", + test_paper_ex_genetic_relatedness_vector_errors }, { "test_paper_ex_Y2_errors", 
test_paper_ex_Y2_errors }, { "test_paper_ex_Y2", test_paper_ex_Y2 }, { "test_paper_ex_f2_errors", test_paper_ex_f2_errors }, diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 6119081a78..9462dfc51c 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -9896,3 +9896,370 @@ tsk_treeseq_pair_coalescence_rates(const tsk_treeseq_t *self, tsk_size_t num_sam out: return ret; } + +/* ======================================================== * + * Relatedness matrix-vector product + * ======================================================== */ + +typedef struct { + const tsk_treeseq_t *ts; + tsk_size_t num_weights; + const double *weights; + tsk_size_t num_windows; + const double *windows; + tsk_flags_t options; + double *result; + /* tree */ + double tree_left; + tsk_size_t num_nodes; + tsk_id_t *parent; + double *x; + double *w; + double *v; +} tsk_matvec_calculator_t; + +static void +tsk_matvec_calculator_print_state(const tsk_matvec_calculator_t *self, FILE *out) +{ + tsk_id_t j; + tsk_size_t num_samples = tsk_treeseq_get_num_samples(self->ts); + + fprintf(out, "Matvec state:\n"); + fprintf(out, "options = %d\n", self->options); + fprintf(out, "tree_left = %f\n", self->tree_left); + fprintf(out, "samples = %lld: [", (long long) num_samples); + fprintf(out, "]\n"); + fprintf(out, "node\tparent\tx\tv\tw"); + fprintf(out, "\n"); + + for (j = 0; j < (tsk_id_t) self->num_nodes; j++) { + fprintf(out, "%lld\t", (long long) j); + fprintf(out, "%lld\t%g\t%g\t%g\n", (long long) self->parent[j], self->x[j], + self->v[j], self->w[j]); + } +} + +static int +tsk_matvec_calculator_init(tsk_matvec_calculator_t *self, const tsk_treeseq_t *ts, + tsk_size_t num_weights, const double *weights, tsk_size_t num_windows, + const double *windows, tsk_flags_t options, double *result) +{ + int ret = 0; + tsk_size_t num_samples = tsk_treeseq_get_num_samples(ts); + const tsk_size_t num_nodes = ts->tables->nodes.num_rows; + const double *row; + double *new_row; + tsk_size_t k; + tsk_id_t u, j; + 
double *weight_means = tsk_malloc(num_weights * sizeof(*weight_means)); + + self->ts = ts; + self->tree_left = 0.0; + self->num_weights = num_weights; + self->weights = weights; + self->num_windows = num_windows; + self->windows = windows; + self->options = options; + self->result = result; + self->num_nodes = num_nodes; + + self->parent = tsk_malloc(num_nodes * sizeof(*self->parent)); + self->x = tsk_calloc(num_nodes, sizeof(*self->x)); + self->v = tsk_calloc(num_nodes * num_weights, sizeof(*self->v)); + self->w = tsk_calloc(num_nodes * num_weights, sizeof(*self->w)); + + if (self->parent == NULL || self->x == NULL || self->w == NULL || self->v == NULL + || weight_means == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + + tsk_memset(result, 0, num_windows * num_samples * num_weights * sizeof(*result)); + tsk_memset(self->parent, TSK_NULL, num_nodes * sizeof(*self->parent)); + + for (k = 0; k < num_weights; k++) { + weight_means[k] = 0.0; + } + if (!(options & TSK_STAT_NONCENTRED)) { + for (j = 0; j < (tsk_id_t) num_samples; j++) { + row = GET_2D_ROW(weights, num_weights, j); + for (k = 0; k < num_weights; k++) { + weight_means[k] += row[k]; + } + } + for (k = 0; k < num_weights; k++) { + weight_means[k] /= (double) num_samples; + } + } + + for (j = 0; j < (tsk_id_t) num_samples; j++) { + u = ts->samples[j]; + row = GET_2D_ROW(weights, num_weights, j); + new_row = GET_2D_ROW(self->w, num_weights, u); + for (k = 0; k < num_weights; k++) { + new_row[k] = row[k] - weight_means[k]; + } + } +out: + tsk_safe_free(weight_means); + return ret; +} + +static int +tsk_matvec_calculator_free(tsk_matvec_calculator_t *self) +{ + tsk_safe_free(self->parent); + tsk_safe_free(self->x); + tsk_safe_free(self->w); + tsk_safe_free(self->v); + + /* Make this safe for multiple free calls */ + memset(self, 0, sizeof(*self)); + return 0; +} + +static inline void +tsk_matvec_calculator_add_z(tsk_id_t u, tsk_id_t p, const double tree_left, + double *restrict x, const tsk_size_t 
num_weights, double *restrict w, + double *restrict v, const double *restrict nodes_time) +{ + double t, span; + tsk_size_t j; + double *restrict v_row, *restrict w_row; + + if (p != TSK_NULL) { + t = nodes_time[p] - nodes_time[u]; + span = tree_left - x[u]; + // do this: self->v[u] += t * span * self->w[u]; + w_row = GET_2D_ROW(w, num_weights, u); + v_row = GET_2D_ROW(v, num_weights, u); + for (j = 0; j < num_weights; j++) { + v_row[j] += t * span * w_row[j]; + } + } + x[u] = tree_left; +} + +static void +tsk_matvec_calculator_adjust_path_up( + tsk_matvec_calculator_t *self, tsk_id_t p, tsk_id_t c, double sign) +{ + tsk_size_t j; + double *p_row, *c_row; + const tsk_id_t *restrict parent = self->parent; + const double tree_left = self->tree_left; + double *restrict x = self->x; + const tsk_size_t num_weights = self->num_weights; + double *restrict w = self->w; + double *restrict v = self->v; + const double *restrict nodes_time = self->ts->tables->nodes.time; + + // sign = -1 for removing edges, +1 for adding + while (p != TSK_NULL) { + tsk_matvec_calculator_add_z( + p, parent[p], tree_left, x, num_weights, w, v, nodes_time); + // do this: self->v[c] -= sign * self->v[p]; + p_row = GET_2D_ROW(v, num_weights, p); + c_row = GET_2D_ROW(v, num_weights, c); + for (j = 0; j < num_weights; j++) { + c_row[j] -= sign * p_row[j]; + } + // do this: self->w[p] += sign * self->w[c]; + p_row = GET_2D_ROW(w, num_weights, p); + c_row = GET_2D_ROW(w, num_weights, c); + for (j = 0; j < num_weights; j++) { + p_row[j] += sign * c_row[j]; + } + p = parent[p]; + } +} + +static void +tsk_matvec_calculator_remove_edge(tsk_matvec_calculator_t *self, tsk_id_t p, tsk_id_t c) +{ + tsk_id_t *parent = self->parent; + const double tree_left = self->tree_left; + double *restrict x = self->x; + const tsk_size_t num_weights = self->num_weights; + double *restrict w = self->w; + double *restrict v = self->v; + const double *restrict nodes_time = self->ts->tables->nodes.time; + + 
tsk_matvec_calculator_add_z( + c, parent[c], tree_left, x, num_weights, w, v, nodes_time); + parent[c] = TSK_NULL; + tsk_matvec_calculator_adjust_path_up(self, p, c, -1); +} + +static void +tsk_matvec_calculator_insert_edge(tsk_matvec_calculator_t *self, tsk_id_t p, tsk_id_t c) +{ + tsk_id_t *parent = self->parent; + + tsk_matvec_calculator_adjust_path_up(self, p, c, +1); + self->x[c] = self->tree_left; + parent[c] = p; +} + +static int +tsk_matvec_calculator_write_output(tsk_matvec_calculator_t *self, double *restrict y) +{ + int ret = 0; + tsk_id_t u; + tsk_size_t j, k; + tsk_size_t n = tsk_treeseq_get_num_samples(self->ts); + const tsk_size_t num_weights = self->num_weights; + const double tree_left = self->tree_left; + double *u_row, *out_row; + double *out_means = tsk_malloc(num_weights * sizeof(*out_means)); + const tsk_id_t *restrict parent = self->parent; + const double *restrict nodes_time = self->ts->tables->nodes.time; + double *restrict x = self->x; + double *restrict w = self->w; + double *restrict v = self->v; + const tsk_id_t *restrict samples = self->ts->samples; + + if (out_means == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + + for (j = 0; j < n; j++) { + out_row = GET_2D_ROW(y, num_weights, j); + u = samples[j]; + while (u != TSK_NULL) { + if (x[u] != tree_left) { + tsk_matvec_calculator_add_z( + u, parent[u], tree_left, x, num_weights, w, v, nodes_time); + } + u_row = GET_2D_ROW(v, num_weights, u); + for (k = 0; k < num_weights; k++) { + out_row[k] += u_row[k]; + } + u = parent[u]; + } + } + + if (!(self->options & TSK_STAT_NONCENTRED)) { + for (k = 0; k < num_weights; k++) { + out_means[k] = 0.0; + } + for (j = 0; j < n; j++) { + out_row = GET_2D_ROW(y, num_weights, j); + for (k = 0; k < num_weights; k++) { + out_means[k] += out_row[k]; + } + } + for (k = 0; k < num_weights; k++) { + out_means[k] /= (double) n; + } + for (j = 0; j < n; j++) { + out_row = GET_2D_ROW(y, num_weights, j); + for (k = 0; k < num_weights; k++) { + out_row[k] 
-= out_means[k]; + } + } + } + /* zero out v */ + tsk_memset(self->v, 0, self->num_nodes * num_weights * sizeof(*self->v)); +out: + tsk_safe_free(out_means); + return ret; +} + +static int +tsk_matvec_calculator_run(tsk_matvec_calculator_t *self) +{ + int ret = 0; + tsk_size_t j, k, m; + tsk_id_t e, p, c; + tsk_size_t n = tsk_treeseq_get_num_samples(self->ts); + double tree_right; + const double sequence_length = self->ts->tables->sequence_length; + const tsk_size_t num_edges = self->ts->tables->edges.num_rows; + const tsk_id_t *restrict I = self->ts->tables->indexes.edge_insertion_order; + const tsk_id_t *restrict O = self->ts->tables->indexes.edge_removal_order; + const double *restrict edge_right = self->ts->tables->edges.right; + const double *restrict edge_left = self->ts->tables->edges.left; + const tsk_id_t *restrict edge_child = self->ts->tables->edges.child; + const tsk_id_t *restrict edge_parent = self->ts->tables->edges.parent; + double *restrict out; + + j = 0; + k = 0; + m = 0; + tree_right = sequence_length; + + while ( + m < self->num_windows && k < num_edges && self->tree_left <= sequence_length) { + while (k < num_edges && edge_right[O[k]] == self->tree_left) { + e = O[k]; + p = edge_parent[e]; + c = edge_child[e]; + tsk_matvec_calculator_remove_edge(self, p, c); + k++; + } + while (j < num_edges && edge_left[I[j]] == self->tree_left) { + e = I[j]; + p = edge_parent[e]; + c = edge_child[e]; + tsk_matvec_calculator_insert_edge(self, p, c); + self->x[c] = self->tree_left; + j++; + } + tree_right = self->windows[m + 1]; + if (j < num_edges) { + tree_right = TSK_MIN(tree_right, edge_left[I[j]]); + } + if (k < num_edges) { + tree_right = TSK_MIN(tree_right, edge_right[O[k]]); + } + self->tree_left = tree_right; + if (self->tree_left == self->windows[m + 1]) { + out = GET_2D_ROW(self->result, self->num_weights * n, m); + tsk_matvec_calculator_write_output(self, out); + m += 1; + } + if (self->options & TSK_DEBUG) { + 
tsk_matvec_calculator_print_state(self, tsk_get_debug_stream()); + } + } + + /* out: */ + return ret; +} + +int +tsk_treeseq_genetic_relatedness_vector(const tsk_treeseq_t *self, tsk_size_t num_weights, + const double *weights, tsk_size_t num_windows, const double *windows, double *result, + tsk_flags_t options) +{ + int ret = 0; + bool stat_site = !!(options & TSK_STAT_SITE); + bool stat_node = !!(options & TSK_STAT_NODE); + tsk_matvec_calculator_t calc; + + memset(&calc, 0, sizeof(calc)); + + if (stat_node || stat_site) { + ret = TSK_ERR_UNSUPPORTED_STAT_MODE; + goto out; + } + ret = tsk_treeseq_check_windows(self, num_windows, windows, TSK_REQUIRE_FULL_SPAN); + if (ret != 0) { + goto out; + } + + ret = tsk_matvec_calculator_init( + &calc, self, num_weights, weights, num_windows, windows, options, result); + if (ret != 0) { + goto out; + } + if (options & TSK_DEBUG) { + tsk_matvec_calculator_print_state(&calc, tsk_get_debug_stream()); + } + ret = tsk_matvec_calculator_run(&calc); +out: + tsk_matvec_calculator_free(&calc); + return ret; +} diff --git a/c/tskit/trees.h b/c/tskit/trees.h index 7169851cdc..df9cf92850 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1025,6 +1025,16 @@ int tsk_treeseq_genetic_relatedness_weighted(const tsk_treeseq_t *self, const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows, double *result, tsk_flags_t options); +/* One way weighted stats with vector output */ + +typedef int weighted_vector_method(const tsk_treeseq_t *self, tsk_size_t num_weights, + const double *weights, tsk_size_t num_windows, const double *windows, double *result, + tsk_flags_t options); + +int tsk_treeseq_genetic_relatedness_vector(const tsk_treeseq_t *self, + tsk_size_t num_weights, const double *weights, tsk_size_t num_windows, + const double *windows, double *result, tsk_flags_t options); + /* One way sample set stats */ typedef int one_way_sample_stat_method(const tsk_treeseq_t *self, diff --git a/docs/python-api.md 
b/docs/python-api.md index 12756ff3e3..1713c1344d 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -323,6 +323,7 @@ Single site TreeSequence.genealogical_nearest_neighbours TreeSequence.genetic_relatedness TreeSequence.genetic_relatedness_weighted + TreeSequence.genetic_relatedness_vector TreeSequence.general_stat TreeSequence.segregating_sites TreeSequence.sample_count_stat diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 1c7ee57b0a..6d275a499a 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -9646,6 +9646,80 @@ TreeSequence_k_way_stat_method(TreeSequence *self, PyObject *args, PyObject *kwd return ret; } +static PyObject * +TreeSequence_weighted_stat_vector_method( + TreeSequence *self, PyObject *args, PyObject *kwds, weighted_vector_method *method) +{ + PyObject *ret = NULL; + static char *kwlist[] + = { "weights", "windows", "mode", "span_normalise", "centre", NULL }; + PyObject *weights = NULL; + PyObject *windows = NULL; + PyArrayObject *weights_array = NULL; + PyArrayObject *windows_array = NULL; + PyArrayObject *result_array = NULL; + tsk_size_t num_windows; + npy_intp *w_shape, result_shape[3]; + tsk_flags_t options = 0; + tsk_size_t num_samples; + char *mode = NULL; + int span_normalise = true; + int centre = true; + int err; + + if (TreeSequence_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO|sii", kwlist, &weights, &windows, + &mode, &span_normalise, &centre)) { + goto out; + } + if (parse_stats_mode(mode, &options) != 0) { + goto out; + } + if (span_normalise) { + options |= TSK_STAT_SPAN_NORMALISE; + } + if (!centre) { + options |= TSK_STAT_NONCENTRED; + } + if (parse_windows(windows, &windows_array, &num_windows) != 0) { + goto out; + } + num_samples = tsk_treeseq_get_num_samples(self->tree_sequence); + weights_array = (PyArrayObject *) PyArray_FROMANY( + weights, NPY_FLOAT64, 2, 2, NPY_ARRAY_IN_ARRAY); + if (weights_array == NULL) { + goto out; + } + w_shape = 
PyArray_DIMS(weights_array); + if (w_shape[0] != (npy_intp) num_samples) { + PyErr_SetString(PyExc_ValueError, "First dimension must be num_samples"); + goto out; + } + + result_shape[0] = num_windows; + result_shape[1] = num_samples; + result_shape[2] = w_shape[1]; + result_array = (PyArrayObject *) PyArray_SimpleNew(3, result_shape, NPY_FLOAT64); + if (result_array == NULL) { + goto out; + } + err = method(self->tree_sequence, w_shape[1], PyArray_DATA(weights_array), + num_windows, PyArray_DATA(windows_array), PyArray_DATA(result_array), options); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = (PyObject *) result_array; + result_array = NULL; +out: + Py_XDECREF(weights_array); + Py_XDECREF(windows_array); + Py_XDECREF(result_array); + return ret; +} + static PyObject * TreeSequence_k_way_weighted_stat_method(TreeSequence *self, PyObject *args, PyObject *kwds, npy_intp tuple_size, two_way_weighted_method *method) @@ -9759,6 +9833,14 @@ TreeSequence_genetic_relatedness_weighted( self, args, kwds, 2, tsk_treeseq_genetic_relatedness_weighted); } +static PyObject * +TreeSequence_genetic_relatedness_vector( + TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_weighted_stat_vector_method( + self, args, kwds, tsk_treeseq_genetic_relatedness_vector); +} + static PyObject * TreeSequence_Y2(TreeSequence *self, PyObject *args, PyObject *kwds) { @@ -11156,6 +11238,10 @@ static PyMethodDef TreeSequence_methods[] = { .ml_meth = (PyCFunction) TreeSequence_genetic_relatedness_weighted, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Computes genetic relatedness between weighted sums of samples." }, + { .ml_name = "genetic_relatedness_vector", + .ml_meth = (PyCFunction) TreeSequence_genetic_relatedness_vector, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes genetic relatedness matrix-vector products." 
}, { .ml_name = "Y1", .ml_meth = (PyCFunction) TreeSequence_Y1, .ml_flags = METH_VARARGS | METH_KEYWORDS, diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 89d38cf5e1..ef8c0b8928 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -2673,6 +2673,112 @@ def test_options(self): ) +class TestGeneticRelatednessVector(LowLevelTestCase): + def get_example(self, num_weights=2): + ts = self.get_example_tree_sequence() + num_samples = ts.get_num_samples() + params = { + "weights": np.linspace(0, 1, num_weights * num_samples).reshape( + (num_samples, num_weights) + ), + "windows": [0, ts.get_sequence_length()], + } + return ts, params + + @pytest.mark.parametrize("mode", ["branch"]) + @pytest.mark.parametrize("num_weights", [1, 3]) + def test_basic_example(self, mode, num_weights): + ts, params = self.get_example(num_weights) + ns = ts.get_num_samples() + result = ts.genetic_relatedness_vector( + params["weights"], params["windows"], mode, True, True + ) + assert result.shape == (1, ns, num_weights) + result = ts.genetic_relatedness_vector( + params["weights"], params["windows"], mode, True, False + ) + assert result.shape == (1, ns, num_weights) + result = ts.genetic_relatedness_vector( + params["weights"], params["windows"], mode, False, True + ) + assert result.shape == (1, ns, num_weights) + + def test_bad_args(self): + ts, params = self.get_example() + for mode in ("", "abc"): + with pytest.raises(ValueError, match="stats mode"): + ts.genetic_relatedness_vector( + params["weights"], params["windows"], mode, True, True + ) + for mode in (None, []): + with pytest.raises(TypeError, match="must be str"): + ts.genetic_relatedness_vector( + params["weights"], params["windows"], mode, True, True + ) + with pytest.raises(TypeError, match="cannot be interp"): + ts.genetic_relatedness_vector( + params["weights"], params["windows"], "branch", "yes", True + ) + with pytest.raises(TypeError, match="cannot be interp"): + 
ts.genetic_relatedness_vector( + params["weights"], params["windows"], "branch", True, "no" + ) + + @pytest.mark.parametrize("mode", ["site", "node"]) + def test_modes_not_supported(self, mode): + ts, params = self.get_example() + with pytest.raises(_tskit.LibraryError): + ts.genetic_relatedness_vector( + params["weights"], params["windows"], mode, True, True + ) + + @pytest.mark.parametrize("mode", ["branch"]) + def test_bad_weights(self, mode): + ts, params = self.get_example() + del params["weights"] + ns = ts.get_num_samples() + for bad_weight_type in [None, [None, None]]: + with pytest.raises(ValueError, match="object of too small depth"): + ts.genetic_relatedness_vector( + weights=bad_weight_type, mode=mode, **params + ) + for bad_weight_shape in [(ns - 1, 1), (ns + 1, 1), (0, 3)]: + with pytest.raises(ValueError, match="First dimension must be num_samples"): + ts.genetic_relatedness_vector( + weights=np.ones(bad_weight_shape), mode=mode, **params + ) + + def test_window_errors(self): + ts, params = self.get_example() + del params["windows"] + for bad_array in ["asdf", None, [[[[]], [[]]]], np.zeros((10, 3, 4))]: + with pytest.raises(ValueError): + ts.genetic_relatedness_vector( + windows=bad_array, mode="branch", **params + ) + + for bad_windows in [[], [0]]: + with pytest.raises(ValueError): + ts.genetic_relatedness_vector( + windows=bad_windows, mode="branch", **params + ) + L = ts.get_sequence_length() + bad_windows = [ + [L, 0], + [0.1, L], + [-1, L], + [0, L + 0.1], + [0, 0.1, 0.1, L], + [0, -1, L], + [0, 0.1, 0.05, 0.2, L], + ] + for bad_window in bad_windows: + with pytest.raises(_tskit.LibraryError): + ts.genetic_relatedness_vector( + windows=bad_window, mode="branch", **params + ) + + class TestGeneralStatsInterface(LowLevelTestCase, StatsInterfaceMixin): """ Tests for the general stats interface. 
diff --git a/python/tests/test_relatedness_vector.py b/python/tests/test_relatedness_vector.py new file mode 100644 index 0000000000..e6ee4555d3 --- /dev/null +++ b/python/tests/test_relatedness_vector.py @@ -0,0 +1,548 @@ +# MIT License +# +# Copyright (c) 2024 Tskit Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Test cases for matrix-vector product stats +""" +import msprime +import numpy as np +import pytest + +import tskit +from tests.test_highlevel import get_example_tree_sequences + +# ↑ See https://github.com/tskit-dev/tskit/issues/1804 for when +# we can remove this. + + +# Implementation note: the class structure here, where we pass in all the +# needed arrays through the constructor was determined by an older version +# in which we used numba acceleration. We could just pass in a reference to +# the tree sequence now, but it is useful to keep track of exactly what we +# require, so leaving it as it is for now. 
+class RelatednessVector: + def __init__( + self, + sample_weights, + windows, + num_nodes, + samples, + nodes_time, + edges_left, + edges_right, + edges_parent, + edges_child, + edge_insertion_order, + edge_removal_order, + sequence_length, + verbosity=0, + internal_checks=False, + centre=True, + ): + self.sample_weights = np.asarray(sample_weights, dtype=np.float64) + self.num_weights = self.sample_weights.shape[1] + self.windows = windows + N = num_nodes + self.parent = np.full(N, -1, dtype=np.int32) + # Edges and indexes + self.edges_left = edges_left + self.edges_right = edges_right + self.edges_parent = edges_parent + self.edges_child = edges_child + self.edge_insertion_order = edge_insertion_order + self.edge_removal_order = edge_removal_order + self.sequence_length = sequence_length + self.nodes_time = nodes_time + self.samples = samples + self.position = 0.0 + self.x = np.zeros(N, dtype=np.float64) + self.w = np.zeros((N, self.num_weights), dtype=np.float64) + self.v = np.zeros((N, self.num_weights), dtype=np.float64) + self.verbosity = verbosity + self.internal_checks = internal_checks + self.centre = centre + + if self.centre: + self.sample_weights -= np.mean(self.sample_weights, axis=0) + + for j, u in enumerate(samples): + self.w[u] = self.sample_weights[j] + + if self.verbosity > 0: + self.print_state("init") + + def print_state(self, msg=""): + num_nodes = len(self.parent) + print(f"..........{msg}................") + print(f"position = {self.position}") + for j in range(num_nodes): + st = f"{self.nodes_time[j]}" + pt = ( + "NaN" + if self.parent[j] == tskit.NULL + else f"{self.nodes_time[self.parent[j]]}" + ) + print( + f"node {j} -> {self.parent[j]}: " + f"z = ({pt} - {st})" + f" * ({self.position} - {self.x[j]:.2})" + f" * {','.join(map(str, self.w[j].round(2)))}" + f" = {','.join(map(str, self.get_z(j).round(2)))}" + ) + print(f" value: {','.join(map(str, self.v[j].round(2)))}") + roots = [] + fmt = "{:<6}{:>8}\t{}\t{}\t{}" + s = f"roots = 
{roots}\n" + s += ( + fmt.format( + "node", + "parent", + "value", + "weight", + "z", + ) + + "\n" + ) + for u in range(num_nodes): + u_str = f"{u}" + s += ( + fmt.format( + u_str, + self.parent[u], + ",".join(map(str, self.v[u].round(2))), + ",".join(map(str, self.w[u].round(2))), + ",".join(map(str, self.get_z(u).round(2))), + ) + + "\n" + ) + print(s) + + print("Current state:") + state = self.current_state() + for j, x in enumerate(state): + print(f" {j}: {x}") + print("..........................") + + def remove_edge(self, p, c): + if self.verbosity > 0: + self.print_state(f"remove {int(p), int(c)}") + assert p != -1 + self.v[c] += self.get_z(c) + self.x[c] = self.position + self.parent[c] = -1 + self.adjust_path_up(p, c, -1) + + def insert_edge(self, p, c): + if self.verbosity > 0: + self.print_state(f"insert {int(p), int(c)}") + assert p != -1 + assert self.parent[c] == -1, "contradictory edges" + self.adjust_path_up(p, c, +1) + self.x[c] = self.position + self.parent[c] = p + + def adjust_path_up(self, p, c, sign): + # sign = -1 for removing edges, +1 for adding + while p != tskit.NULL: + self.v[p] += self.get_z(p) + self.x[p] = self.position + self.v[c] -= sign * self.v[p] + self.w[p] += sign * self.w[c] + p = self.parent[p] + + def get_z(self, u): + p = self.parent[u] + if p == tskit.NULL: + return np.zeros(self.num_weights, dtype=np.float64) + time = self.nodes_time[p] - self.nodes_time[u] + span = self.position - self.x[u] + return time * span * self.w[u] + + def mrca(self, a, b): + # just used for `current_state` + aa = [a] + while a != tskit.NULL: + a = self.parent[a] + aa.append(a) + while b not in aa: + b = self.parent[b] + return b + + def write_output(self): + """ + Compute and return the current state, zero-ing out + all contributions (used for switching between windows). 
+ """ + n = len(self.samples) + out = np.zeros((n, self.num_weights)) + for j, c in enumerate(self.samples): + while c != tskit.NULL: + if self.x[c] != self.position: + self.v[c] += self.get_z(c) + self.x[c] = self.position + out[j] += self.v[c] + c = self.parent[c] + self.v *= 0.0 + return out + + def current_state(self): + """ + Compute the current output, for debugging. + """ + if self.verbosity > 2: + print("---------------") + n = len(self.samples) + out = np.zeros((n, self.num_weights)) + for j, a in enumerate(self.samples): + # edges on the path up from a + pa = a + while pa != tskit.NULL: + if self.verbosity > 2: + print("edge:", pa, self.get_z(pa)) + out[j] += self.get_z(pa) + self.v[pa] + pa = self.parent[pa] + if self.verbosity > 2: + print("---------------") + return out + + def run(self): + M = self.edges_left.shape[0] + in_order = self.edge_insertion_order + out_order = self.edge_removal_order + edges_left = self.edges_left + edges_right = self.edges_right + edges_parent = self.edges_parent + edges_child = self.edges_child + num_windows = len(self.windows) - 1 + out = np.zeros((num_windows,) + self.sample_weights.shape) + + j = 0 + k = 0 + m = 0 + self.position = 0 + + while m < num_windows and k < M and self.position <= self.sequence_length: + while k < M and edges_right[out_order[k]] == self.position: + p = edges_parent[out_order[k]] + c = edges_child[out_order[k]] + self.remove_edge(p, c) + k += 1 + while j < M and edges_left[in_order[j]] == self.position: + p = edges_parent[in_order[j]] + c = edges_child[in_order[j]] + self.insert_edge(p, c) + assert self.parent[p] == tskit.NULL or self.x[p] == self.position + j += 1 + right = self.windows[m + 1] + if j < M: + right = min(right, edges_left[in_order[j]]) + if k < M: + right = min(right, edges_right[out_order[k]]) + self.position = right + if self.position == self.windows[m + 1]: + out[m] = self.write_output() + m = m + 1 + + if self.verbosity > 1: + self.print_state() + + if self.centre: + for m in 
def relatedness_matrix(ts, windows, centre):
    """
    Return the dense sample-by-sample relatedness matrix for ``ts``.

    Each sample is placed in its own singleton sample set and
    ``ts.genetic_relatedness`` is evaluated for every ordered pair of
    sample-set indexes, using ``mode="branch"``, ``span_normalise=False``
    and ``proportion=False``.

    :param ts: Object providing ``samples()``, ``num_samples`` and
        ``genetic_relatedness``.
    :param windows: Window breakpoints passed straight through, or None.
    :param centre: Passed straight through as the ``centre`` argument.
    :return: The result reshaped to ``(num_samples, num_samples)`` when
        ``windows`` is None, otherwise
        ``(len(windows) - 1, num_samples, num_samples)``.
    """
    singletons = [[u] for u in ts.samples()]
    all_pairs = [
        (a, b) for a in range(ts.num_samples) for b in range(ts.num_samples)
    ]
    flat = ts.genetic_relatedness(
        sample_sets=singletons,
        indexes=all_pairs,
        windows=windows,
        mode="branch",
        span_normalise=False,
        proportion=False,
        centre=centre,
    )
    matrix_shape = (ts.num_samples, ts.num_samples)
    if windows is not None:
        # One square matrix per window.
        matrix_shape = (len(windows) - 1,) + matrix_shape
    return flat.reshape(matrix_shape)
print(ts.draw_text()) + print("weights:", w) + print("windows:", windows) + print("here:", R1) + print("with ts:", R2) + print("with lib:", R3) + print("Sigma:", Sigma) + if windows is None: + assert R1.shape == (ts.num_samples, wvec.shape[1]) + else: + assert R1.shape == (len(windows) - 1, ts.num_samples, wvec.shape[1]) + np.testing.assert_allclose(R1, R2, atol=1e-13) + np.testing.assert_allclose(R1, R3, atol=1e-13) + return R1 + + +def check_relatedness_vector( + ts, n=2, num_windows=0, *, internal_checks=False, verbosity=0, seed=123, centre=True +): + rng = np.random.default_rng(seed=seed) + if num_windows == 0: + windows = None + else: + windows = np.linspace(0, ts.sequence_length, num_windows + 1) + for k in range(n): + if k == 0: + w = rng.normal(size=ts.num_samples) + else: + w = rng.normal(size=ts.num_samples * k).reshape((ts.num_samples, k)) + w = np.round(len(w) * w) + R = verify_relatedness_vector( + ts, + w, + windows, + internal_checks=internal_checks, + verbosity=verbosity, + centre=centre, + ) + return R + + +class TestExamples: + + def test_bad_weights(self): + n = 5 + ts = msprime.sim_ancestry( + n, + ploidy=2, + sequence_length=10, + random_seed=123, + ) + for bad_W in (None, [1], np.ones((3 * n, 2)), np.ones((n - 1, 2))): + with pytest.raises(ValueError, match="number of samples"): + ts.genetic_relatedness_vector(bad_W, mode="branch") + + def test_bad_windows(self): + n = 5 + ts = msprime.sim_ancestry( + n, + ploidy=2, + sequence_length=10, + random_seed=123, + ) + for bad_w in ([1], []): + with pytest.raises(ValueError, match="Windows array"): + ts.genetic_relatedness_vector( + np.ones(ts.num_samples), windows=bad_w, mode="branch" + ) + + @pytest.mark.parametrize("n", [2, 3, 5]) + @pytest.mark.parametrize("seed", range(1, 4)) + @pytest.mark.parametrize("centre", (True, False)) + @pytest.mark.parametrize("num_windows", (0, 1, 2)) + def test_small_internal_checks(self, n, seed, centre, num_windows): + ts = msprime.sim_ancestry( + n, + ploidy=1, + 
sequence_length=1000, + recombination_rate=0.01, + random_seed=seed, + ) + assert ts.num_trees >= 2 + check_relatedness_vector(ts, internal_checks=True, centre=centre) + + @pytest.mark.parametrize("n", [2, 3, 5, 15]) + @pytest.mark.parametrize("seed", range(1, 5)) + @pytest.mark.parametrize("centre", (True, False)) + @pytest.mark.parametrize("num_windows", (0, 1, 3)) + def test_simple_sims(self, n, seed, centre, num_windows): + ts = msprime.sim_ancestry( + n, + ploidy=1, + population_size=20, + sequence_length=100, + recombination_rate=0.01, + random_seed=seed, + ) + assert ts.num_trees >= 2 + check_relatedness_vector( + ts, num_windows=num_windows, centre=centre, verbosity=0 + ) + + @pytest.mark.parametrize("n", [2, 3, 5, 15]) + @pytest.mark.parametrize("centre", (True, False)) + def test_single_balanced_tree(self, n, centre): + ts = tskit.Tree.generate_balanced(n).tree_sequence + check_relatedness_vector(ts, internal_checks=True, centre=centre) + + @pytest.mark.parametrize("centre", (True, False)) + def test_internal_sample(self, centre): + tables = tskit.Tree.generate_balanced(4).tree_sequence.dump_tables() + flags = tables.nodes.flags + flags[3] = 0 + flags[5] = tskit.NODE_IS_SAMPLE + tables.nodes.flags = flags + ts = tables.tree_sequence() + check_relatedness_vector(ts, centre=centre) + + @pytest.mark.parametrize("seed", range(1, 5)) + @pytest.mark.parametrize("centre", (True, False)) + @pytest.mark.parametrize("num_windows", (0, 1, 2)) + def test_one_internal_sample_sims(self, seed, centre, num_windows): + ts = msprime.sim_ancestry( + 10, + ploidy=1, + population_size=20, + sequence_length=100, + recombination_rate=0.01, + random_seed=seed, + ) + t = ts.dump_tables() + # Add a new sample directly below another sample + u = t.nodes.add_row(time=-1, flags=tskit.NODE_IS_SAMPLE) + t.edges.add_row(parent=0, child=u, left=0, right=ts.sequence_length) + t.sort() + t.build_index() + ts = t.tree_sequence() + check_relatedness_vector(ts, num_windows=num_windows, 
centre=centre) + + @pytest.mark.parametrize("centre", (True, False)) + @pytest.mark.parametrize("num_windows", (0, 1, 2)) + def test_missing_flanks(self, centre, num_windows): + ts = msprime.sim_ancestry( + 2, + ploidy=1, + population_size=10, + sequence_length=100, + recombination_rate=0.001, + random_seed=1234, + ) + assert ts.num_trees >= 2 + ts = ts.keep_intervals([[20, 80]]) + assert ts.first().interval == (0, 20) + check_relatedness_vector(ts, num_windows=num_windows, centre=centre) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("centre", (True, False)) + def test_suite_examples(self, ts, centre): + if ts.num_samples > 0: + check_relatedness_vector(ts, centre=centre) + + @pytest.mark.parametrize("n", [2, 3, 10]) + def test_dangling_on_samples(self, n): + # Adding non sample branches below the samples does not alter + # the overall divergence *between* the samples + ts1 = tskit.Tree.generate_balanced(n).tree_sequence + D1 = check_relatedness_vector(ts1) + tables = ts1.dump_tables() + for u in ts1.samples(): + v = tables.nodes.add_row(time=-1) + tables.edges.add_row(left=0, right=ts1.sequence_length, parent=u, child=v) + tables.sort() + tables.build_index() + ts2 = tables.tree_sequence() + D2 = check_relatedness_vector(ts2, internal_checks=True) + np.testing.assert_array_almost_equal(D1, D2) + + @pytest.mark.parametrize("n", [2, 3, 10]) + @pytest.mark.parametrize("centre", (True, False)) + def test_dangling_on_all(self, n, centre): + # Adding non sample branches below the samples does not alter + # the overall divergence *between* the samples + ts1 = tskit.Tree.generate_balanced(n).tree_sequence + D1 = check_relatedness_vector(ts1, centre=centre) + tables = ts1.dump_tables() + for u in range(ts1.num_nodes): + v = tables.nodes.add_row(time=-1) + tables.edges.add_row(left=0, right=ts1.sequence_length, parent=u, child=v) + tables.sort() + tables.build_index() + ts2 = tables.tree_sequence() + D2 = 
check_relatedness_vector(ts2, internal_checks=True, centre=centre) + np.testing.assert_array_almost_equal(D1, D2) + + @pytest.mark.parametrize("centre", (True, False)) + def test_disconnected_non_sample_topology(self, centre): + # Adding non sample branches below the samples does not alter + # the overall divergence *between* the samples + ts1 = tskit.Tree.generate_balanced(5).tree_sequence + D1 = check_relatedness_vector(ts1, centre=centre) + tables = ts1.dump_tables() + # Add an extra bit of disconnected non-sample topology + u = tables.nodes.add_row(time=0) + v = tables.nodes.add_row(time=1) + tables.edges.add_row(left=0, right=ts1.sequence_length, parent=v, child=u) + tables.sort() + tables.build_index() + ts2 = tables.tree_sequence() + D2 = check_relatedness_vector(ts2, internal_checks=True, centre=centre) + np.testing.assert_array_almost_equal(D1, D2) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 5c4120301a..eea322e4dc 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -7838,6 +7838,28 @@ def __k_way_weighted_stat( stat = stat.reshape(stat.shape[:-1]) return stat + def __weighted_vector_stat( + self, + ll_method, + W, + windows=None, + mode=None, + span_normalise=True, + centre=True, + ): + W = np.asarray(W) + if len(W.shape) == 1: + W = W.reshape(W.shape[0], 1) + stat = self.__run_windowed_stat( + windows, + ll_method, + W, + mode=mode, + span_normalise=span_normalise, + centre=centre, + ) + return stat + ############################################ # Statistics definitions ############################################ @@ -8477,6 +8499,54 @@ def genetic_relatedness_weighted( centre=centre, ) + def genetic_relatedness_vector( + self, + W, + windows=None, + mode="site", + span_normalise=True, + centre=True, + ): + r""" + Computes the product of the genetic relatedness matrix and a vector of weights + (one per sample). 
The output is a (num windows) x (num samples) x (num weights) + array whose :math:`(i,j)`-th element is :math:`\sum_{b} W_{bj} C_{ib}`, + where :math:`W` is the matrix of weights, and :math:`C_{ab}` is the + :meth:`genetic_relatedness <.TreeSequence.genetic_relatedness>` between sample + a and sample b, and the sum is over all samples in the tree sequence. + Like other statistics, if windows is None, the first dimension in the output is + dropped. + + The relatedness used here corresponds to `polarised=True`; no unpolarised option + is available for this method. + + :param numpy.ndarray W: An array of values with one row for each sample node and + one column for each set of weights. + :param list windows: An increasing list of breakpoints between the windows + to compute the statistic in. + :param str mode: A string giving the "type" of the statistic to be computed + (defaults to "site"). + :param bool span_normalise: Whether to divide the result by the span of the + window (defaults to True). + :param bool centre: Whether to use the *centred* relatedness matrix or not: + see :meth:`genetic_relatedness <.TreeSequence.genetic_relatedness>`. + :return: A ndarray with shape equal to (num windows, num samples, num weights), + or (num samples, num weights) if windows is None. + """ + if (not hasattr(W, "__len__")) or (len(W) != self.num_samples): + raise ValueError( + "First weight dimension must be equal to number of samples." + ) + out = self.__weighted_vector_stat( + self._ll_tree_sequence.genetic_relatedness_vector, + W, + windows=windows, + mode=mode, + span_normalise=span_normalise, + centre=centre, + ) + return out + def trait_covariance(self, W, windows=None, mode="site", span_normalise=True): """ Computes the mean squared covariances between each of the columns of ``W``