-
Notifications
You must be signed in to change notification settings - Fork 72
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
randomized svd draft #3008
base: main
Are you sure you want to change the base?
randomized svd draft #3008
Changes from 2 commits
1f45245
e408ab3
a176132
8c662c8
aa13613
5bf405a
6e415e5
45ac61e
2cdb9dd
1f194aa
fdc5842
c0a2854
667d60f
36c10e9
0ac88b8
a6fa8ab
791c7da
2f4ce2b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,6 +40,7 @@ | |
from typing import NamedTuple | ||
|
||
import numpy as np | ||
import scipy.sparse | ||
|
||
import _tskit | ||
import tskit | ||
|
@@ -8592,6 +8593,133 @@ def genetic_relatedness_vector( | |
) | ||
return out | ||
|
||
def pca(
    self,
    n_components: int = 10,
    iterated_power: int = 3,
    n_oversamples: int = 10,
    samples: np.ndarray = None,
    individuals: np.ndarray = None,
    centre: bool = True,
    windows=None,
    random_state: np.random.Generator = None,
) -> "tuple[np.ndarray, np.ndarray]":
    """
    Run randomized singular value decomposition (rSVD) on the branch-mode
    genetic relatedness matrix (GRM) to obtain principal components.
    The GRM is never formed explicitly: it is only accessed through
    matrix-matrix products computed by :meth:`.genetic_relatedness_vector`.
    API partially adopted from `scikit-learn`:
    https://scikit-learn.org/dev/modules/generated/sklearn.decomposition.PCA.html

    Exactly one of ``samples`` and ``individuals`` may be given. If neither
    is given, all samples of the tree sequence (``self.samples()``) are used.
    When ``individuals`` is given, the relatedness of two individuals is the
    sum of the relatedness over all pairs of their nodes.

    :param int n_components: Number of principal components to return.
        Must be positive and no larger than the number of samples/individuals.
    :param int iterated_power: Number of power iterations of the range finder.
    :param int n_oversamples: Number of additional random test vectors used
        on top of ``n_components``. Must be non-negative.
    :param np.ndarray samples: Node IDs of the samples to perform PCA on.
    :param np.ndarray individuals: Individual IDs to perform PCA on.
    :param bool centre: Centre the genetic relatedness matrix.
    :param windows: Passed unchanged to :meth:`.genetic_relatedness_vector`
        (an increasing list of breakpoints between the windows to compute
        the genetic relatedness matrix in). NOTE(review): the factorisation
        below treats the operator output as a 2-D matrix; confirm the
        behaviour when windowed output is requested.
    :param random_state: Source of randomness for the test vectors. Anything
        accepted by :func:`numpy.random.default_rng`: ``None`` (fresh
        entropy), an integer seed, or an existing ``np.random.Generator``
        (used as-is).
    :return: A tuple ``(U, D)`` where the columns of ``U`` (shape
        ``(dim, n_components)``) are the leading left singular vectors of
        the GRM and ``D`` holds the matching singular values in
        non-increasing order.
    :raises ValueError: If both ``samples`` and ``individuals`` are given,
        if ``n_components`` or ``n_oversamples`` is out of range, or if none
        of the requested individuals has any nodes.
    """
    # Imported locally so the helper annotations below always resolve,
    # independent of what the module imports from ``typing``.
    from typing import Callable

    def _rand_pow_range_finder(
        operator: Callable,
        operator_dim: int,
        rank: int,
        depth: int,
        num_vectors: int,
        rng: np.random.Generator,
    ) -> np.ndarray:
        """
        Algorithm 9 in https://arxiv.org/pdf/2002.01387

        Returns the first ``rank`` columns of an orthonormal basis obtained
        from ``depth`` rounds of (orthonormalise, apply ``operator``)
        starting from ``num_vectors`` Gaussian test vectors.
        """
        Q = rng.normal(size=(operator_dim, num_vectors))
        for _ in range(depth):
            # Tuple unpacking instead of ``.Q`` keeps this working on NumPy
            # versions where ``qr`` returns a plain tuple.
            Q, _ = np.linalg.qr(Q)
            Q = operator(Q)
        Q, _ = np.linalg.qr(Q)
        return Q[:, :rank]

    def _rand_svd(
        operator: Callable,
        operator_dim: int,
        rank: int,
        depth: int,
        num_vectors: int,
        rng: np.random.Generator,
    ) -> "tuple[np.ndarray, np.ndarray, np.ndarray]":
        """
        Algorithm 8 in https://arxiv.org/pdf/2002.01387

        Returns the leading ``rank`` left singular vectors, singular values
        and right singular vectors (as rows) of ``operator``.
        """
        # Keep every basis vector here (rank=num_vectors); truncation to
        # ``rank`` happens only after the small dense SVD.
        Q = _rand_pow_range_finder(
            operator, operator_dim, num_vectors, depth, num_vectors, rng
        )
        # (operator @ Q).T equals Q.T @ operator because the GRM is symmetric.
        C = operator(Q).T
        U_hat, D, V = np.linalg.svd(C, full_matrices=False)
        U = Q @ U_hat
        return U[:, :rank], D[:rank], V[:rank]

    if n_components <= 0:
        raise ValueError("n_components must be positive")
    if n_oversamples < 0:
        raise ValueError("n_oversamples must be non-negative")

    # default_rng accepts None, an integer seed or an existing Generator
    # (which it returns unaltered), so callers may pass any of them.
    rng = np.random.default_rng(random_state)

    if samples is None and individuals is None:
        samples = self.samples()

    if samples is not None and individuals is not None:
        raise ValueError("samples and individuals cannot be used at the same time")
    elif samples is not None:
        samples = np.asarray(samples)
        dim = samples.size

        def _G(x):
            return self.genetic_relatedness_vector(
                x, windows=windows, mode="branch", centre=centre, nodes=samples
            )

    else:
        individuals = np.asarray(individuals)
        dim = individuals.size
        # Build the (node, position-of-owning-individual) map once, rather
        # than on every application of the operator.
        pairs = [
            (node, k)
            for k, ind in enumerate(individuals)
            for node in self.individual(ind).nodes
        ]
        if len(pairs) == 0:
            raise ValueError("None of the given individuals has any nodes")
        node_map = np.array(pairs)
        ind_nodes, node_owner = node_map[:, 0], node_map[:, 1]

        def _G(arr):
            """
            Multiply ``arr`` (one row per individual) by the individual-level
            GRM, centring with respect to individuals when ``centre`` is set.
            """
            arr = np.asarray(arr)
            if arr.shape[0] != dim:
                raise ValueError("Dimension mismatch")
            x = arr - arr.mean(axis=0) if centre else arr
            # Expand each individual's row to all of its nodes, multiply by
            # the uncentred node-level GRM, then sum the node rows back per
            # individual. Centring is applied on both sides by hand because
            # it must be with respect to individuals, not nodes.
            x = np.asarray(
                self.genetic_relatedness_vector(
                    W=x[node_owner],
                    windows=windows,
                    mode="branch",
                    centre=False,
                    nodes=ind_nodes,
                )
            )
            # Row-wise scatter-add: works for vectors and matrices alike and
            # always yields one row per individual, including individuals
            # that have no nodes (their rows stay zero).
            out = np.zeros((dim,) + x.shape[1:])
            np.add.at(out, node_owner, x)
            return out - out.mean(axis=0) if centre else out

    if n_components > dim:
        raise ValueError(
            "n_components must not exceed the number of samples/individuals"
        )

    U, D, _ = _rand_svd(
        operator=_G,
        operator_dim=dim,
        rank=n_components,
        depth=iterated_power,
        num_vectors=n_components + n_oversamples,
        rng=rng,
    )

    return U, D
|
||
|
||
def trait_covariance(self, W, windows=None, mode="site", span_normalise=True): | ||
""" | ||
Computes the mean squared covariances between each of the columns of ``W`` | ||
|
@@ -10171,3 +10299,4 @@ def write_ms( | |
) | ||
else: | ||
print(file=output) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Usually we just pass in a `seed`; any objections to doing that instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I changed the option from `random_state` to `random_seed`, following msprime.