diff --git a/python/tests/test_relatedness_vector.py b/python/tests/test_relatedness_vector.py index f765c75c9f..6488fa7b24 100644 --- a/python/tests/test_relatedness_vector.py +++ b/python/tests/test_relatedness_vector.py @@ -460,7 +460,7 @@ def check_relatedness_vector( return R -class TestExamples: +class TestRelatednessVector: def test_bad_weights(self): n = 5 @@ -737,3 +737,147 @@ def test_disconnected_non_sample_topology(self, centre): ts2, internal_checks=True, centre=centre, do_nodes=False ) np.testing.assert_array_almost_equal(D1, D2) + + +def pca(ts, windows, centre): + drop_dimension = windows is None + if drop_dimension: + windows = [0, ts.sequence_length] + Sigma = relatedness_matrix(ts=ts, windows=windows, centre=centre) + U, S, _ = np.linalg.svd(Sigma, hermitian=True) + if drop_dimension: + U = U[0] + S = S[0] + return U, S + + +def allclose_up_to_sign(x, y, **kwargs): + # check if two vectors are the same up to sign + x_const = np.isclose(np.std(x), 0) + y_const = np.isclose(np.std(y), 0) + if x_const or y_const: + if np.allclose(x, 0): + r = 1.0 + else: + r = np.mean(x / y) + else: + r = np.sign(np.corrcoef(x, y)[0, 1]) + return np.allclose(x, r * y, **kwargs) + + +def assert_pcs_equal(U, D, U_full, D_full, rtol=1e-05, atol=1e-08): + # check that the PCs in U, D occur in U_full, D_full + # accounting for sign and ordering + assert len(D) <= len(D_full) + assert U.shape[0] == U_full.shape[0] + assert U.shape[1] == len(D) + for k in range(len(D)): + u = U[:, k] + d = D[k] + (ii,) = np.where(np.isclose(D_full, d, rtol=rtol, atol=atol)) + assert len(ii) > 0, f"{k}th singular value {d} not found in {D_full}." + found_it = False + for i in ii: + if allclose_up_to_sign(u, U_full[:, i], rtol=rtol, atol=atol): + found_it = True + break + assert found_it, f"{k}th singular vector {u} not found in {U_full}." 
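

# Illustrative sketch (hypothetical helper, not used by the test classes below):
# shows how the reference `pca` helper and `assert_pcs_equal` fit together,
# mirroring TestPCA.verify_pca for the single-window case. It assumes that
# `ts.pca` returns a (U, D, range_sketch, error_bound) tuple and that `ts` is a
# tree sequence simulated as in the tests below. Sign and column order of the
# singular vectors may differ between the exact and the randomized
# decompositions, which is what `allclose_up_to_sign`/`assert_pcs_equal` allow for.
def _example_pca_check(ts, k=2):
    U_full, D_full = pca(ts=ts, windows=None, centre=True)  # exact SVD of the GRM
    U_hat, D_hat, _, _ = ts.pca(n_components=k, centre=True, random_seed=1)
    # The leading singular values should be recovered closely ...
    np.testing.assert_allclose(D_hat, D_full[:k], atol=1e-8)
    # ... and each leading vector should match an exact one up to sign/ordering.
    assert_pcs_equal(U_hat, D_hat, U_full, D_full)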


class TestPCA:

    def verify_pca(self, ts, num_windows, n_components, centre):
        if num_windows == 0:
            windows = None
        elif num_windows % 2 == 0:
            windows = np.linspace(
                0.2 * ts.sequence_length, 0.8 * ts.sequence_length, num_windows + 1
            )
        else:
            windows = np.linspace(0, ts.sequence_length, num_windows + 1)
        ts_U, ts_D, _, _ = ts.pca(
            windows=windows, n_components=n_components, centre=centre, random_seed=123
        )
        num_rows = ts.num_samples
        if windows is None:
            assert ts_U.shape == (num_rows, n_components)
            assert ts_D.shape == (n_components,)
        else:
            assert ts_U.shape == (num_windows, num_rows, n_components)
            assert ts_D.shape == (num_windows, n_components)
        U, D = pca(ts=ts, windows=windows, centre=centre)
        if windows is None:
            np.testing.assert_allclose(ts_D, D[:n_components], atol=1e-8)
            assert_pcs_equal(ts_U, ts_D, U, D)
        else:
            for w in range(num_windows):
                np.testing.assert_allclose(ts_D[w], D[w, :n_components], atol=1e-8)
                assert_pcs_equal(ts_U[w], ts_D[w], U[w], D[w])

    def test_bad_windows(self):
        ts = msprime.sim_ancestry(
            3,
            ploidy=2,
            sequence_length=10,
            random_seed=123,
        )
        for bad_w in ([], [1]):
            with pytest.raises(ValueError, match="Number of windows"):
                ts.pca(n_components=2, windows=bad_w)
        for bad_w in ([1, 0], [-3, 10]):
            with pytest.raises(tskit.LibraryError, match="TSK_ERR_BAD_WINDOWS"):
                ts.pca(n_components=2, windows=bad_w)

    def test_bad_num_components(self):
        ts = msprime.sim_ancestry(
            3,
            ploidy=2,
            sequence_length=10,
            random_seed=123,
        )
        with pytest.raises(ValueError, match="Number of components"):
            ts.pca(n_components=ts.num_samples + 1)
        with pytest.raises(ValueError, match="Number of components"):
            ts.pca(n_components=4, samples=[0, 1, 2])
        with pytest.raises(ValueError, match="Number of components"):
            ts.pca(n_components=4, individuals=[0, 1])

    def test_indivs_and_samples(self):
        ts = msprime.sim_ancestry(
            3,
            ploidy=2,
            sequence_length=10,
            random_seed=123,
        )
        with pytest.raises(ValueError, match="Samples and individuals"):
            ts.pca(n_components=2, samples=[0, 1, 2, 3], individuals=[0, 1, 2])

    def test_modes(self):
        ts = msprime.sim_ancestry(
            3,
            ploidy=2,
            sequence_length=10,
            random_seed=123,
        )
        for bad_mode in ("site", "node"):
            with pytest.raises(
                tskit.LibraryError, match="TSK_ERR_UNSUPPORTED_STAT_MODE"
            ):
                ts.pca(n_components=2, mode=bad_mode)

    @pytest.mark.parametrize("n", [2, 3, 5, 15])
    @pytest.mark.parametrize("centre", (True, False))
    @pytest.mark.parametrize("num_windows", (0, 1, 2, 3))
    @pytest.mark.parametrize("n_components", (1, 3))
    def test_simple_sims(self, n, centre, num_windows, n_components):
        ploidy = 1
        nc = min(n_components, n * ploidy)
        ts = msprime.sim_ancestry(
            n,
            ploidy=ploidy,
            population_size=20,
            sequence_length=100,
            recombination_rate=0.01,
            random_seed=12345,
        )
        self.verify_pca(ts, num_windows=num_windows, n_components=nc, centre=centre)
diff --git a/python/tskit/trees.py b/python/tskit/trees.py
index 7d3bbca7b9..6064e7556f 100644
--- a/python/tskit/trees.py
+++ b/python/tskit/trees.py
@@ -8592,6 +8592,216 @@ def genetic_relatedness_vector(
        )
        return out

    def pca(
        self,
        n_components: int,
        windows: list = None,
        samples: np.ndarray = None,
        individuals: np.ndarray = None,
        mode: str = "branch",
        centre: bool = True,
        iterated_power: int = 5,
        num_oversamples: int = 10,
        random_seed: int = None,
        range_sketch: np.ndarray = None,
    ) -> (np.ndarray, np.ndarray, np.ndarray, np.ndarray):
        """
        Run a randomized singular value decomposition (rSVD) to obtain
        the principal components. The API is partially adapted from
        scikit-learn's ``sklearn.decomposition.PCA``.

        By default the PCA is carried out for all samples (so the output has
        one coordinate for each sample), but alternatively either a list of
        sample IDs or a list of individual IDs can be provided (but not both).

        TODO: say exactly what is returned (and relationship to
        :meth:`genetic_relatedness <.TreeSequence.genetic_relatedness>`).

        TODO: say what algorithms are used.

        :param int n_components: Number of principal components to return.
        :param list windows: An increasing list of breakpoints between the windows
            to compute the statistic in.
        :param np.ndarray samples: The IDs of the sample nodes to perform the PCA
            with. Defaults to all samples.
        :param np.ndarray individuals: The IDs of the individuals to perform the
            PCA with. Cannot be specified together with `samples`.
        :param str mode: A string giving the "type" of relatedness to be computed
            (defaults to "branch"; see
            :meth:`genetic_relatedness_vector
            <.TreeSequence.genetic_relatedness_vector>`).
        :param bool centre: Whether to centre the genetic relatedness matrix.
            Defaults to True.
        :param int iterated_power: Number of power iterations used by the range
            finder.
        :param int num_oversamples: Number of additional test vectors used by the
            range finder.
        :param int random_seed: The random seed. If this is None, a random seed
            will be automatically generated. Valid random seeds must be between 1
            and :math:`2^{32} - 1`.
        :param np.ndarray range_sketch: Optional sketch matrix for each window,
            used as the starting point of the range finder (for instance, the
            range sketch returned by a previous call). Default is None.
        :return: A tuple (U, D, Q, E) of ndarrays, with the principal component
            loadings in U and the principal values in D. Q is the range sketch
            array for each window, and E is the corresponding error bound for the
            randomized SVD.
        """

        if samples is None and individuals is None:
            samples = self.samples()

        if samples is not None and individuals is not None:
            raise ValueError("Samples and individuals cannot be used at the same time")
        elif samples is not None:
            output_type = "node"
            dim = len(samples)
        else:
            assert individuals is not None
            output_type = "individual"
            dim = len(individuals)

        if range_sketch is not None:
            if windows is not None:
                # One sketch matrix is expected per window.
                assert range_sketch.shape[0] == len(windows) - 1
            else:
                range_sketch = np.expand_dims(range_sketch, 0)

        if n_components > dim:
            raise ValueError(
                "Number of components must be less than or equal to "
                "the number of samples (or individuals, if specified)."
            )

        random_state = np.random.default_rng(random_seed)

        def _rand_pow_range_finder(
            operator,
            operator_dim: int,
            rank: int,
            depth: int,
            num_vectors: int,
            rng: np.random.Generator,
            range_sketch: np.ndarray = None,
        ) -> np.ndarray:
            """
            Algorithm 9 in https://arxiv.org/pdf/2002.01387
            """
            assert num_vectors >= rank > 0, "num_vectors must be at least rank"
            if range_sketch is None:
                test_vectors = rng.normal(size=(operator_dim, num_vectors))
                Q = test_vectors
            else:
                Q = range_sketch
            for _ in range(depth):
                Q = np.linalg.qr(Q).Q
                Q = operator(Q)
            Q = np.linalg.qr(Q).Q
            return Q[:, :rank]

        def _rand_svd(
            operator,
            operator_dim: int,
            rank: int,
            depth: int,
            num_vectors: int,
            rng: np.random.Generator,
            range_sketch: np.ndarray = None,
        ) -> (np.ndarray, np.ndarray, np.ndarray, np.ndarray, float):
            """
            Algorithm 8 in https://arxiv.org/pdf/2002.01387
            """
            assert num_vectors >= rank > 0
            # Keep the full sketch (rank=num_vectors here) so that it can be
            # returned and reused as range_sketch in a later call.
            Q = _rand_pow_range_finder(
                operator,
                operator_dim,
                num_vectors,
                depth,
                num_vectors,
                rng,
                range_sketch,
            )
            C = operator(Q).T
            U_hat, D, V = np.linalg.svd(C, full_matrices=False)
            U = Q @ U_hat

            if rank > 1:
                error_factor = np.power(
                    1 + 4 * np.sqrt(2 * operator_dim / (rank - 1)),
                    1 / (2 * depth + 1),
                )
                error_bound = D[-1] * (2 + error_factor)
            else:
                # The error factor above is undefined for rank == 1 (it divides
                # by rank - 1), so only the leading term of the bound is kept.
                error_bound = 2 * D[-1]
            return U[:, :rank], D[:rank], V[:rank], Q, error_bound

        def _genetic_relatedness_vector_individual(
            arr: np.ndarray,
            centre: bool = True,
            windows=None,
        ) -> np.ndarray:
            ij = np.vstack(
                [
                    [n, k]
                    for k, i in enumerate(individuals)
                    for n in self.individual(i).nodes
                ]
            )
            samples, sample_individuals = (
                ij[:, 0],
                ij[:, 1],
            )  # sample node index, individual of those nodes
            x = (
                arr - arr.mean(axis=0) if centre else arr
            )  # centering within index in rows
            x = self.genetic_relatedness_vector(
                W=x[sample_individuals],
                windows=windows,
                mode=mode,
                centre=False,
                nodes=samples,
            )[0]

            def bincount_fn(w):
                return np.bincount(sample_individuals, w)

            x = np.apply_along_axis(bincount_fn, axis=0, arr=x)
            x = x - x.mean(axis=0) if centre else x  # centering within index in cols

            return x

        def _genetic_relatedness_vector_node(
            arr: np.ndarray,
            centre: bool = True,
            windows=None,
        ) -> np.ndarray:
            x = arr - arr.mean(axis=0) if centre else arr
            x = self.genetic_relatedness_vector(
                W=x, windows=windows, mode=mode, centre=False, nodes=samples
            )[0]
            x = x - x.mean(axis=0) if centre else x

            return x

        drop_windows = windows is None
        windows = self.parse_windows(windows)
        num_windows = len(windows) - 1
        if num_windows < 1:
            raise ValueError("Number of windows must be at least 1.")

        # The sketch cannot have more columns than the operator dimension, so
        # the number of test vectors is capped at dim.
        num_vectors = min(n_components + num_oversamples, dim)
        U = np.empty((num_windows, dim, n_components))
        D = np.empty((num_windows, n_components))
        Q = np.empty((num_windows, dim, num_vectors))
        E = np.empty(num_windows)
        for i in range(num_windows):
            this_window = windows[i : i + 2]
            _f = (
                _genetic_relatedness_vector_node
                if output_type == "node"
                else _genetic_relatedness_vector_individual
            )

            def _G(x):
                return _f(x, centre=centre, windows=this_window)  # NOQA: B023

            U[i], D[i], _, Q[i], E[i] = _rand_svd(
                operator=_G,
                operator_dim=dim,
                rank=n_components,
                depth=iterated_power,
                num_vectors=num_vectors,
                rng=random_state,
                range_sketch=None if range_sketch is None else range_sketch[i],
            )

        if drop_windows:
            U, D, Q, E = U[0], D[0], Q[0], E[0]

        return U, D, Q, E

    def trait_covariance(self, W, windows=None, mode="site", span_normalise=True):
        """
        Computes the mean
squared covariances between each of the columns of ``W``