diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 2bad40c5b2..d1cba394ac 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -9651,7 +9651,8 @@ TreeSequence_weighted_stat_vector_method( TreeSequence *self, PyObject *args, PyObject *kwds, weighted_vector_method *method) { PyObject *ret = NULL; - static char *kwlist[] = { "weights", "windows", "mode", "span_normalise", NULL }; + static char *kwlist[] + = { "weights", "windows", "mode", "span_normalise", "centre", NULL }; PyObject *weights = NULL; PyObject *windows = NULL; PyArrayObject *weights_array = NULL; @@ -9663,13 +9664,14 @@ TreeSequence_weighted_stat_vector_method( tsk_size_t num_samples = tsk_treeseq_get_num_samples(self->tree_sequence); char *mode = NULL; int span_normalise = true; + int centre = true; int err; if (TreeSequence_check_state(self) != 0) { goto out; } - if (!PyArg_ParseTupleAndKeywords( - args, kwds, "OO|si", kwlist, &weights, &windows, &mode, &span_normalise)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO|sii", kwlist, &weights, &windows, + &mode, &span_normalise, &centre)) { goto out; } if (parse_stats_mode(mode, &options) != 0) { @@ -9678,6 +9680,9 @@ TreeSequence_weighted_stat_vector_method( if (span_normalise) { options |= TSK_STAT_SPAN_NORMALISE; } + if (!centre) { + options |= TSK_STAT_NONCENTRED; + } if (parse_windows(windows, &windows_array, &num_windows) != 0) { goto out; } diff --git a/python/tests/test_relatedness_vector.py b/python/tests/test_relatedness_vector.py index 152fce372e..6d9d26e23c 100644 --- a/python/tests/test_relatedness_vector.py +++ b/python/tests/test_relatedness_vector.py @@ -54,8 +54,10 @@ def __init__( sequence_length, verbosity=0, internal_checks=False, + centre=True, ): self.sample_weights = np.asarray(sample_weights, dtype=np.float64) + num_weights = self.sample_weights.shape[1] # virtual root is at num_nodes; virtual samples are beyond that N = num_nodes + 1 + len(samples) self.parent = np.full(N, -1, dtype=np.int32) @@ 
-72,10 +74,14 @@ def __init__( self.position = 0 self.virtual_root = num_nodes self.x = np.zeros(N, dtype=np.float64) - self.w = np.zeros(N, dtype=np.float64) - self.v = np.zeros(N, dtype=np.float64) + self.w = np.zeros((N, num_weights), dtype=np.float64) + self.v = np.zeros((N, num_weights), dtype=np.float64) self.verbosity = verbosity self.internal_checks = internal_checks + self.centre = centre + + if self.centre: + self.sample_weights -= np.mean(self.sample_weights, axis=0) for j, u in enumerate(samples): self.w[u] = self.sample_weights[j] @@ -254,10 +260,12 @@ def run(self): if self.verbosity > 1: self.print_state() - out = np.zeros(len(self.samples)) + out = np.zeros(self.sample_weights.shape) for out_i in range(len(self.samples)): i = out_i + self.virtual_root + 1 out[out_i] = self.v[i] + if self.centre: + out -= np.mean(out, axis=0) return out @@ -279,28 +287,31 @@ def relatedness_vector(ts, sample_weights, **kwargs): return rv.run() -def relatedness_matrix(ts): +def relatedness_matrix(ts, centre): Sigma = ts.genetic_relatedness( sample_sets=[[i] for i in ts.samples()], indexes=[(i, j) for i in range(ts.num_samples) for j in range(ts.num_samples)], mode="branch", span_normalise=False, proportion=False, - centre=False, + centre=centre, ).reshape((ts.num_samples, ts.num_samples)) return Sigma def verify_relatedness_vector( - ts, w, *, internal_checks=False, verbosity=0, centre=False + ts, w, *, internal_checks=False, verbosity=0, centre=True ): - w = np.round(len(w) * w) R1 = relatedness_vector( - ts, sample_weights=w, internal_checks=internal_checks, verbosity=verbosity + ts, + sample_weights=w, + internal_checks=internal_checks, + verbosity=verbosity, + centre=centre, ) - Sigma = relatedness_matrix(ts) + Sigma = relatedness_matrix(ts, centre=centre) R2 = Sigma.dot(w) - R3 = ts.genetic_relatedness_vector(w, mode="branch") + # R3 = ts.genetic_relatedness_vector(w, mode="branch", centre=centre) if verbosity > 0: print(ts.draw_text()) print("weights:", w) @@ 
-308,16 +319,23 @@ def verify_relatedness_vector( print("with ts:", R2) print("Sigma:", Sigma) np.testing.assert_allclose(R1, R2, atol=1e-14) - np.testing.assert_allclose(R1, R3) + # np.testing.assert_allclose(R1, R3) return R1 -def check_relatedness_vector(ts, n=5, *, internal_checks=False, verbosity=0, seed=123): +def check_relatedness_vector( + ts, n=5, *, internal_checks=False, verbosity=0, seed=123, centre=True +): rng = np.random.default_rng(seed=seed) - for _ in range(n): - w = rng.normal(size=ts.num_samples) + for k in range(n): + w = rng.normal(size=ts.num_samples * k).reshape((ts.num_samples, k)) + w = np.round(len(w) * w) R = verify_relatedness_vector( - ts, w, internal_checks=internal_checks, verbosity=verbosity + ts, + w, + internal_checks=internal_checks, + verbosity=verbosity, + centre=centre, ) return R @@ -325,7 +343,8 @@ def check_relatedness_vector(ts, n=5, *, internal_checks=False, verbosity=0, see class TestExamples: @pytest.mark.parametrize("n", [2, 3, 5]) @pytest.mark.parametrize("seed", range(1, 4)) - def test_small_internal_checks(self, n, seed): + @pytest.mark.parametrize("centre", (True, False)) + def test_small_internal_checks(self, n, seed, centre): ts = msprime.sim_ancestry( n, ploidy=1, @@ -334,11 +353,12 @@ def test_small_internal_checks(self, n, seed): random_seed=seed, ) assert ts.num_trees >= 2 - check_relatedness_vector(ts, internal_checks=True) + check_relatedness_vector(ts, internal_checks=True, centre=centre) @pytest.mark.parametrize("n", [2, 3, 5, 15]) @pytest.mark.parametrize("seed", range(1, 5)) - def test_simple_sims(self, n, seed): + @pytest.mark.parametrize("centre", (True, False)) + def test_simple_sims(self, n, seed, centre): ts = msprime.sim_ancestry( n, ploidy=1, @@ -348,25 +368,28 @@ def test_simple_sims(self, n, seed): random_seed=seed, ) assert ts.num_trees >= 2 - check_relatedness_vector(ts) + check_relatedness_vector(ts, centre=centre) @pytest.mark.parametrize("n", [2, 3, 5, 15]) - def 
test_single_balanced_tree(self, n): + @pytest.mark.parametrize("centre", (True, False)) + def test_single_balanced_tree(self, n, centre): ts = tskit.Tree.generate_balanced(n).tree_sequence - check_relatedness_vector(ts, internal_checks=True, verbosity=1) + check_relatedness_vector(ts, internal_checks=True, centre=centre) - def test_internal_sample(self): + @pytest.mark.parametrize("centre", (True, False)) + def test_internal_sample(self, centre): tables = tskit.Tree.generate_balanced(4).tree_sequence.dump_tables() flags = tables.nodes.flags flags[3] = 0 flags[5] = tskit.NODE_IS_SAMPLE tables.nodes.flags = flags ts = tables.tree_sequence() - check_relatedness_vector(ts, verbosity=0) + check_relatedness_vector(ts, centre=centre) # @pytest.mark.skip() @pytest.mark.parametrize("seed", range(1, 5)) - def test_one_internal_sample_sims(self, seed): + @pytest.mark.parametrize("centre", (True, False)) + def test_one_internal_sample_sims(self, seed, centre): ts = msprime.sim_ancestry( 10, ploidy=1, @@ -382,9 +405,10 @@ def test_one_internal_sample_sims(self, seed): t.sort() t.build_index() ts = t.tree_sequence() - check_relatedness_vector(ts) + check_relatedness_vector(ts, centre=centre) - def test_missing_flanks(self): + @pytest.mark.parametrize("centre", (True, False)) + def test_missing_flanks(self, centre): ts = msprime.sim_ancestry( 2, ploidy=1, @@ -396,12 +420,13 @@ def test_missing_flanks(self): assert ts.num_trees >= 2 ts = ts.keep_intervals([[20, 80]]) assert ts.first().interval == (0, 20) - check_relatedness_vector(ts, verbosity=2) + check_relatedness_vector(ts, centre=centre) @pytest.mark.parametrize("ts", get_example_tree_sequences()) - def test_suite_examples(self, ts): + @pytest.mark.parametrize("centre", (True, False)) + def test_suite_examples(self, ts, centre): if ts.num_samples > 0: - check_relatedness_vector(ts) + check_relatedness_vector(ts, centre=centre) @pytest.mark.parametrize("n", [2, 3, 10]) def test_dangling_on_samples(self, n): @@ -420,11 +445,12 
@@ def test_dangling_on_samples(self, n): np.testing.assert_array_almost_equal(D1, D2) @pytest.mark.parametrize("n", [2, 3, 10]) - def test_dangling_on_all(self, n): + @pytest.mark.parametrize("centre", (True, False)) + def test_dangling_on_all(self, n, centre): # Adding non sample branches below the samples does not alter # the overall divergence *between* the samples ts1 = tskit.Tree.generate_balanced(n).tree_sequence - D1 = check_relatedness_vector(ts1) + D1 = check_relatedness_vector(ts1, centre=centre) tables = ts1.dump_tables() for u in range(ts1.num_nodes): v = tables.nodes.add_row(time=-1) @@ -432,14 +458,15 @@ def test_dangling_on_all(self, n): tables.sort() tables.build_index() ts2 = tables.tree_sequence() - D2 = check_relatedness_vector(ts2, internal_checks=True) + D2 = check_relatedness_vector(ts2, internal_checks=True, centre=centre) np.testing.assert_array_almost_equal(D1, D2) - def test_disconnected_non_sample_topology(self): + @pytest.mark.parametrize("centre", (True, False)) + def test_disconnected_non_sample_topology(self, centre): # Adding non sample branches below the samples does not alter # the overall divergence *between* the samples ts1 = tskit.Tree.generate_balanced(5).tree_sequence - D1 = check_relatedness_vector(ts1) + D1 = check_relatedness_vector(ts1, centre=centre) tables = ts1.dump_tables() # Add an extra bit of disconnected non-sample topology u = tables.nodes.add_row(time=0) @@ -448,5 +475,5 @@ def test_disconnected_non_sample_topology(self): tables.sort() tables.build_index() ts2 = tables.tree_sequence() - D2 = check_relatedness_vector(ts2, internal_checks=True) + D2 = check_relatedness_vector(ts2, internal_checks=True, centre=centre) np.testing.assert_array_almost_equal(D1, D2) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 06158269db..f4946c0154 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -7771,6 +7771,7 @@ def __weighted_vector_stat( windows=None, mode=None, span_normalise=True, + 
centre=True, ): W = np.asarray(W) if len(W.shape) == 1: @@ -7781,6 +7782,7 @@ def __weighted_vector_stat( W, mode=mode, span_normalise=span_normalise, + centre=centre, ) return stat @@ -8429,7 +8431,7 @@ def genetic_relatedness_vector( windows=None, mode="site", span_normalise=True, - polarised=True, + centre=True, ): r""" Computes the product of the genetic relatedness matrix and a vector of weights @@ -8439,6 +8441,9 @@ def genetic_relatedness_vector( :meth:`genetic_relatedness <.TreeSequence.genetic_relatedness>` between sample a and sample b, and the sum is over all samples in the tree sequence. + The relatedness used here corresponds to `polarised=True`; no unpolarised option + is available for this method. + :param numpy.ndarray W: An array of values with one row for each sample node and one column for each set of weights. :param list windows: An increasing list of breakpoints between the windows @@ -8447,20 +8452,21 @@ def genetic_relatedness_vector( (defaults to "site"). :param bool span_normalise: Whether to divide the result by the span of the window (defaults to True). + :param bool centre: Whether to use the *centred* relatedness matrix or not: + see :meth:`genetic_relatedness <.TreeSequence.genetic_relatedness>`. :return: A ndarray with shape equal to (num windows, num weights). """ if len(W) != self.num_samples: raise ValueError( "First trait dimension must be equal to number of samples." ) - if not polarised: - raise ValueError("genetic_relatedness_vector is not available unpolarised.") return self.__weighted_vector_stat( self._ll_tree_sequence.genetic_relatedness_vector, W, windows=windows, mode=mode, span_normalise=span_normalise, + centre=centre, ) def trait_covariance(self, W, windows=None, mode="site", span_normalise=True):