Skip to content

Commit

Permalink
Full pass at tests for genetic_relatedness_weighted
Browse files Browse the repository at this point in the history
  • Loading branch information
brieuclehmann committed Jul 7, 2023
1 parent eff37eb commit 81aca50
Showing 1 changed file with 108 additions and 12 deletions.
120 changes: 108 additions & 12 deletions python/tests/test_tree_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2141,28 +2141,100 @@ def test_match_K_c0(self):
############################################
# Genetic relatedness weighted
############################################
# still need to implement multiple index pairs and multiple windows


def genetic_relatedness_matrix(ts, sample_sets, mode, windows=None):
def genetic_relatedness_matrix(ts, sample_sets, windows=None, mode="site"):
n = len(sample_sets)
indexes = [
(n1, n2) for n1, n2 in itertools.combinations_with_replacement(range(n), 2)
]
K = np.zeros((n, n))
K[np.triu_indices(n)] = ts.genetic_relatedness(
sample_sets, indexes, mode=mode, proportion=False, span_normalise=False
)
K = K + np.triu(K, 1).transpose()
if windows is None:
if mode == "node":
n_nodes = ts.num_nodes
K = np.zeros((n_nodes, n, n))
out = ts.genetic_relatedness(
sample_sets, indexes, mode=mode, proportion=False, span_normalise=False
)
for node in range(n_nodes):
this_K = np.zeros((n, n))
this_K[np.triu_indices(n)] = out[node, :]
this_K = this_K + np.triu(this_K, 1).transpose()
K[node, :, :] = this_K
else:
K = np.zeros((n, n))
K[np.triu_indices(n)] = ts.genetic_relatedness(
sample_sets, indexes, mode=mode, proportion=False, span_normalise=True
)
K = K + np.triu(K, 1).transpose()
else:
windows = ts.parse_windows(windows)
n_windows = len(windows) - 1
out = ts.genetic_relatedness(
sample_sets,
indexes,
mode=mode,
windows=windows,
proportion=False,
span_normalise=True,
)
if mode == "node":
n_nodes = ts.num_nodes
K = np.zeros((n_windows, n_nodes, n, n))
for win in range(n_windows):
for node in range(n_nodes):
K_this = np.zeros((n, n))
K_this[np.triu_indices(n)] = out[win, node, :]
K_this = K_this + np.triu(K_this, 1).transpose()
K[win, node, :, :] = K_this
else:
K = np.zeros((n_windows, n, n))
for win in range(n_windows):
K_this = np.zeros((n, n))
K_this[np.triu_indices(n)] = out[win, :]
K_this = K_this + np.triu(K_this, 1).transpose()
K[win, :, :] = K_this
return K


def genetic_relatedness_weighted(ts, W, indexes, windows=None, mode="site"):
W_mean = W.mean(axis=0)
W = W - W_mean
sample_sets = [[u] for u in ts.samples()]
K = genetic_relatedness_matrix(ts, sample_sets, mode)
i1 = indexes[0]
i2 = indexes[1]
return W[:, i1] @ K @ W[:, i2]
K = genetic_relatedness_matrix(ts, sample_sets, windows, mode)
n_indexes = len(indexes)
n_nodes = ts.num_nodes
if windows is None:
if mode == "node":
out = np.zeros((n_nodes, n_indexes))
else:
out = np.zeros(n_indexes)
else:
windows = ts.parse_windows(windows)
n_windows = len(windows) - 1
if mode == "node":
out = np.zeros((n_windows, n_nodes, n_indexes))
else:
out = np.zeros((n_windows, n_indexes))
for pair in range(n_indexes):
i1 = indexes[pair][0]
i2 = indexes[pair][1]
if windows is None:
if mode == "node":
for node in range(n_nodes):
this_K = K[node, :, :]
out[node, pair] = W[:, i1] @ this_K @ W[:, i2]
else:
out[pair] = W[:, i1] @ K @ W[:, i2]
else:
for win in range(n_windows):
if mode == "node":
for node in range(n_nodes):
this_K = K[win, node, :, :]
out[win, node, pair] = W[:, i1] @ this_K @ W[:, i2]
else:
this_K = K[win, :, :]
out[win, pair] = W[:, i1] @ this_K @ W[:, i2]
return out


def example_index_pairs(weights):
Expand All @@ -2174,6 +2246,8 @@ def example_index_pairs(weights):


class TestGeneticRelatednessWeighted(StatsTestCase, WeightStatsMixin):

# Derived classes define this to get a specific stats mode.
mode = None

def verify_definition(
Expand Down Expand Up @@ -2212,8 +2286,12 @@ def verify(self, ts):
def verify_weighted_stat(self, ts, W, indexes, windows):
n = W.shape[0]

# THIS IS WRONG, COPIED FROM GENETIC_RELATEDNESS
def f(x):
return (x**2) / (2 * (n - 1) * (n - 1))
mx = np.sum(x) / n
return np.array(
[(x[i] - n[i] * mx) * (x[j] - n[j] * mx) / 2 for i, j in indexes]
)

self.verify_definition(
ts,
Expand All @@ -2226,6 +2304,24 @@ def f(x):
)


class TestBranchGeneticRelatednessWeighted(
TestGeneticRelatednessWeighted, TopologyExamplesMixin
):
mode = "branch"


class TestNodeGeneticRelatednessWeighted(
TestGeneticRelatednessWeighted, TopologyExamplesMixin
):
mode = "node"


class TestSiteGeneticRelatednessWeighted(
TestGeneticRelatednessWeighted, MutatedTopologyExamplesMixin
):
mode = "site"


############################################
# Fst
############################################
Expand Down

0 comments on commit 81aca50

Please sign in to comment.