diff --git a/docs/python-api.md b/docs/python-api.md index ac53d3fce9..a8236daadf 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -321,6 +321,7 @@ Single site TreeSequence.Fst TreeSequence.genealogical_nearest_neighbours TreeSequence.genetic_relatedness + TreeSequence.genetic_relatedness_weighted TreeSequence.general_stat TreeSequence.segregating_sites TreeSequence.sample_count_stat diff --git a/docs/stats.md b/docs/stats.md index 39257e1017..72aa5d615b 100644 --- a/docs/stats.md +++ b/docs/stats.md @@ -71,6 +71,7 @@ appears beside the listed method. * Multi-way * {meth}`~TreeSequence.divergence` * {meth}`~TreeSequence.genetic_relatedness` + {meth}`~TreeSequence.genetic_relatedness_weighted` * {meth}`~TreeSequence.f4` {meth}`~TreeSequence.f3` {meth}`~TreeSequence.f2` @@ -593,6 +594,12 @@ and boolean expressions (e.g., {math}`(x > 0)`) are interpreted as 0/1. where {math}`m = \frac{1}{n}\sum_{k=1}^n x_k` with {math}`n` the total number of samples. +`genetic_relatedness_weighted` +: {math}`f(w_i, w_j, x_i, x_j) = \frac{1}{2}(x_i - w_i m) (x_j - w_j m)`, + + where {math}`m = \frac{1}{n}\sum_{k=1}^n x_k` with {math}`n` the total number + of samples, and {math}`w_j = \sum_{k=1}^n W_{kj}` is the sum of the weights in the {math}`j`th column of the weight matrix. 
+ `Y2` : {math}`f(x_1, x_2) = \frac{x_1 (n_2 - x_2) (n_2 - x_2 - 1)}{n_1 n_2 (n_2 - 1)}` diff --git a/python/tests/test_tree_stats.py b/python/tests/test_tree_stats.py index 7725931b73..a06a690483 100644 --- a/python/tests/test_tree_stats.py +++ b/python/tests/test_tree_stats.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # Copyright (C) 2016 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -2101,6 +2101,203 @@ def test_match_K_c0(self): self.assertArrayAlmostEqual(A, B) +############################################ +# Genetic relatedness weighted +############################################ + + +def genetic_relatedness_matrix(ts, sample_sets, windows=None, mode="site"): + n = len(sample_sets) + indexes = [ + (n1, n2) for n1, n2 in itertools.combinations_with_replacement(range(n), 2) + ] + if windows is None: + if mode == "node": + n_nodes = ts.num_nodes + K = np.zeros((n_nodes, n, n)) + out = ts.genetic_relatedness( + sample_sets, indexes, mode=mode, proportion=False, span_normalise=True + ) + for node in range(n_nodes): + this_K = np.zeros((n, n)) + this_K[np.triu_indices(n)] = out[node, :] + this_K = this_K + np.triu(this_K, 1).transpose() + K[node, :, :] = this_K + else: + K = np.zeros((n, n)) + K[np.triu_indices(n)] = ts.genetic_relatedness( + sample_sets, indexes, mode=mode, proportion=False, span_normalise=True + ) + K = K + np.triu(K, 1).transpose() + else: + windows = ts.parse_windows(windows) + n_windows = len(windows) - 1 + out = ts.genetic_relatedness( + sample_sets, + indexes, + mode=mode, + windows=windows, + proportion=False, + span_normalise=True, + ) + if mode == "node": + n_nodes = ts.num_nodes + K = np.zeros((n_windows, n_nodes, n, n)) + for win in range(n_windows): + for node in range(n_nodes): + K_this = np.zeros((n, n)) + K_this[np.triu_indices(n)] = out[win, node, :] + K_this = K_this + np.triu(K_this, 
1).transpose() + K[win, node, :, :] = K_this + else: + K = np.zeros((n_windows, n, n)) + for win in range(n_windows): + K_this = np.zeros((n, n)) + K_this[np.triu_indices(n)] = out[win, :] + K_this = K_this + np.triu(K_this, 1).transpose() + K[win, :, :] = K_this + return K + + +def genetic_relatedness_weighted(ts, W, indexes, windows=None, mode="site"): + W_mean = W.mean(axis=0) + W = W - W_mean + sample_sets = [[u] for u in ts.samples()] + K = genetic_relatedness_matrix(ts, sample_sets, windows, mode) + n_indexes = len(indexes) + n_nodes = ts.num_nodes + if windows is None: + if mode == "node": + out = np.zeros((n_nodes, n_indexes)) + else: + out = np.zeros(n_indexes) + else: + windows = ts.parse_windows(windows) + n_windows = len(windows) - 1 + if mode == "node": + out = np.zeros((n_windows, n_nodes, n_indexes)) + else: + out = np.zeros((n_windows, n_indexes)) + for pair in range(n_indexes): + i1 = indexes[pair][0] + i2 = indexes[pair][1] + if windows is None: + if mode == "node": + for node in range(n_nodes): + this_K = K[node, :, :] + out[node, pair] = W[:, i1] @ this_K @ W[:, i2] + else: + out[pair] = W[:, i1] @ K @ W[:, i2] + else: + for win in range(n_windows): + if mode == "node": + for node in range(n_nodes): + this_K = K[win, node, :, :] + out[win, node, pair] = W[:, i1] @ this_K @ W[:, i2] + else: + this_K = K[win, :, :] + out[win, pair] = W[:, i1] @ this_K @ W[:, i2] + return out + + +def example_index_pairs(weights): + assert weights.shape[1] >= 2 + yield [(0, 1)] + yield [(1, 0), (0, 1)] + if weights.shape[1] > 2: + yield [(0, 1), (1, 2), (0, 2)] + + +class TestGeneticRelatednessWeighted(StatsTestCase, WeightStatsMixin): + + # Derived classes define this to get a specific stats mode. 
+ mode = None + + def verify_definition( + self, ts, W, indexes, windows, summary_func, ts_method, definition + ): + + # Determine output_dim of the function + M = len(indexes) + + sigma1 = ts.general_stat( + W, summary_func, M, windows, mode=self.mode, span_normalise=True + ) + sigma2 = general_stat( + ts, W, summary_func, windows, mode=self.mode, span_normalise=True + ) + + sigma3 = ts_method( + W, + indexes=indexes, + windows=windows, + mode=self.mode, + ) + sigma4 = definition( + ts, + W, + indexes=indexes, + windows=windows, + mode=self.mode, + ) + assert sigma1.shape == sigma2.shape + assert sigma1.shape == sigma3.shape + assert sigma1.shape == sigma4.shape + self.assertArrayAlmostEqual(sigma1, sigma2) + self.assertArrayAlmostEqual(sigma1, sigma3) + self.assertArrayAlmostEqual(sigma1, sigma4) + + def verify(self, ts): + for W, windows in subset_combos( + self.example_weights(ts, min_size=2), example_windows(ts), p=0.1 + ): + for indexes in example_index_pairs(W): + self.verify_weighted_stat(ts, W, indexes, windows) + + def verify_weighted_stat(self, ts, W, indexes, windows): + W_mean = W.mean(axis=0) + W = W - W_mean + W_sum = W.sum(axis=0) + n = W.shape[0] + + def f(x): + mx = np.sum(x) / n + return np.array( + [ + (x[i] - W_sum[i] * mx) * (x[j] - W_sum[j] * mx) / 2 + for i, j in indexes + ] + ) + + self.verify_definition( + ts, + W, + indexes, + windows, + f, + ts.genetic_relatedness_weighted, + genetic_relatedness_weighted, + ) + + +class TestBranchGeneticRelatednessWeighted( + TestGeneticRelatednessWeighted, TopologyExamplesMixin +): + mode = "branch" + + +class TestNodeGeneticRelatednessWeighted( + TestGeneticRelatednessWeighted, TopologyExamplesMixin +): + mode = "node" + + +class TestSiteGeneticRelatednessWeighted( + TestGeneticRelatednessWeighted, MutatedTopologyExamplesMixin +): + mode = "site" + + ############################################ # Fst ############################################ diff --git a/python/tskit/trees.py b/python/tskit/trees.py 
index 356f1fcfe4..67b4db68aa 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8000,11 +8000,14 @@ def genetic_relatedness_weighted( then the k-th column of output will be :math:`\sum_{a,b} W_{ai} W_{bj} C_{ab}`, where :math:`W` is the matrix of weights, and :math:`C_{ab}` is the - {meth}`.genetic_relatedness` between sample i and sample j. - - :param numpy.ndarray W: An array of values with one row for each sample and one - column for each set of weights. - :param list indexes: A list of 2-tuples, or None. + :meth:`genetic_relatedness <.TreeSequence.genetic_relatedness>` between sample + a and sample b, summing over all pairs of samples in the tree sequence. + + :param numpy.ndarray W: An array of values with one row for each sample node and + one column for each set of weights. + :param list indexes: A list of 2-tuples, or None (default). Note that if + indexes = None, then W must have exactly two columns and this is equivalent + to indexes = [(0,1)]. :param list windows: An increasing list of breakpoints between the windows to compute the statistic in. :param str mode: A string giving the "type" of the statistic to be computed