Skip to content

Commit

Permalink
documentation for genetic_relatedness_weighted
Browse files Browse the repository at this point in the history
First pass at genetic_relatedness_weighted tests

Full pass at tests for genetic_relatedness_weighted

Update python/tskit/trees.py

Co-authored-by: Peter Ralph <[email protected]>

Add summary func to genetic_relatedness_weighted tests

Fix summary function definition in docs
  • Loading branch information
brieuclehmann authored and jeromekelleher committed Jul 13, 2023
1 parent d5d24e1 commit 6c144d8
Show file tree
Hide file tree
Showing 4 changed files with 214 additions and 6 deletions.
1 change: 1 addition & 0 deletions docs/python-api.md
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ Single site
TreeSequence.Fst
TreeSequence.genealogical_nearest_neighbours
TreeSequence.genetic_relatedness
TreeSequence.genetic_relatedness_weighted
TreeSequence.general_stat
TreeSequence.segregating_sites
TreeSequence.sample_count_stat
Expand Down
7 changes: 7 additions & 0 deletions docs/stats.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ appears beside the listed method.
* Multi-way
* {meth}`~TreeSequence.divergence`
* {meth}`~TreeSequence.genetic_relatedness`
{meth}`~TreeSequence.genetic_relatedness_weighted`
* {meth}`~TreeSequence.f4`
{meth}`~TreeSequence.f3`
{meth}`~TreeSequence.f2`
Expand Down Expand Up @@ -593,6 +594,12 @@ and boolean expressions (e.g., {math}`(x > 0)`) are interpreted as 0/1.
where {math}`m = \frac{1}{n}\sum_{k=1}^n x_k` with {math}`n` the total number
of samples.

`genetic_relatedness_weighted`
: {math}`f(w_i, w_j, x_i, x_j) = \frac{1}{2}(x_i - w_i m) (x_j - w_j m)`,

where {math}`m = \frac{1}{n}\sum_{k=1}^n x_k` with {math}`n` the total number
of samples, and {math}`w_j = \sum_{k=1}^n W_{kj}` is the sum of the weights in the {math}`j`th column of the weight matrix.

`Y2`
: {math}`f(x_1, x_2) = \frac{x_1 (n_2 - x_2) (n_2 - x_2 - 1)}{n_1 n_2 (n_2 - 1)}`

Expand Down
199 changes: 198 additions & 1 deletion python/tests/test_tree_stats.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# MIT License
#
# Copyright (c) 2018-2022 Tskit Developers
# Copyright (c) 2018-2023 Tskit Developers
# Copyright (C) 2016 University of Oxford
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
Expand Down Expand Up @@ -2101,6 +2101,203 @@ def test_match_K_c0(self):
self.assertArrayAlmostEqual(A, B)


############################################
# Genetic relatedness weighted
############################################


def genetic_relatedness_matrix(ts, sample_sets, windows=None, mode="site"):
    """
    Return the symmetric matrix of pairwise ``ts.genetic_relatedness`` values
    between the given sample sets.

    The statistic is evaluated once for every pair ``(i, j)`` with ``i <= j``
    (so each set is also paired with itself), using ``proportion=False`` and
    ``span_normalise=True``. The flat per-pair output is then unpacked into a
    symmetric ``(n, n)`` matrix, where ``n = len(sample_sets)``.

    :param ts: The tree sequence whose ``genetic_relatedness`` method is used.
    :param list sample_sets: The sample sets forwarded to
        ``ts.genetic_relatedness``.
    :param windows: Window specification, passed through
        ``ts.parse_windows`` before use, or None to compute without windows.
    :param str mode: The stats mode forwarded to ``ts.genetic_relatedness``.
    :return: An array whose last two axes have shape ``(n, n)``. The leading
        axes are whatever the statistic returns ahead of its per-pair axis:
        none, ``(n_nodes,)``, ``(n_windows,)`` or ``(n_windows, n_nodes)``.
    """
    n = len(sample_sets)
    # combinations_with_replacement enumerates the pairs (i, j), i <= j, in
    # row-major order, which is exactly the order np.triu_indices produces.
    indexes = list(itertools.combinations_with_replacement(range(n), 2))

    stat_kwargs = {"mode": mode, "proportion": False, "span_normalise": True}
    if windows is not None:
        stat_kwargs["windows"] = ts.parse_windows(windows)
    out = ts.genetic_relatedness(sample_sets, indexes, **stat_kwargs)

    # The per-pair values always sit on the last axis of ``out``; any window
    # and/or node axes come first and are carried through unchanged.
    rows, cols = np.triu_indices(n)
    upper = np.zeros(out.shape[:-1] + (n, n))
    upper[..., rows, cols] = out
    # Mirror the strict upper triangle into the lower triangle. np.triu acts
    # on the final two axes, so this symmetrises every matrix in the stack.
    return upper + np.swapaxes(np.triu(upper, 1), -1, -2)


def genetic_relatedness_weighted(ts, W, indexes, windows=None, mode="site"):
    """
    Reference implementation of the weighted genetic relatedness statistic,
    built from the matrix of relatedness values between individual samples.

    Each column of ``W`` is first centred by subtracting its mean. Writing
    ``K`` for the relatedness matrix between single-sample sample sets (one
    per entry of ``ts.samples()``), the value for the index pair ``(i, j)``
    is the bilinear form ``W[:, i] @ K @ W[:, j]``.

    :param ts: The tree sequence to compute the statistic on.
    :param numpy.ndarray W: Weights, indexed as ``W[sample, column]``.
    :param list indexes: Pairs ``(i, j)`` of column indexes into ``W``.
    :param windows: Window specification forwarded to
        ``genetic_relatedness_matrix``, or None to compute without windows.
    :param str mode: The stats mode forwarded to
        ``genetic_relatedness_matrix``.
    :return: An array with the leading (window and/or node) axes of ``K``
        followed by one final axis of length ``len(indexes)``.
    """
    W = W - W.mean(axis=0)
    sample_sets = [[u] for u in ts.samples()]
    K = genetic_relatedness_matrix(ts, sample_sets, windows, mode)

    # K is a stack of (n, n) matrices over any window/node axes, so the output
    # shape is those leading axes plus one slot per requested index pair.
    out = np.zeros(K.shape[:-2] + (len(indexes),))
    for pair, (i1, i2) in enumerate(indexes):
        # matmul broadcasts the 1-D weight vectors over the leading axes of
        # K, evaluating the bilinear form for every matrix in the stack.
        out[..., pair] = W[:, i1] @ K @ W[:, i2]
    return out


def example_index_pairs(weights):
    """
    Yield example ``indexes`` arguments (lists of column-index pairs) that are
    valid for the given weights array.

    Two lists using only columns 0 and 1 are always yielded; a third list that
    also refers to column 2 is yielded when ``weights`` has more than two
    columns. At least two columns are required.
    """
    num_columns = weights.shape[1]
    assert num_columns >= 2
    examples = [
        [(0, 1)],
        [(1, 0), (0, 1)],
    ]
    if num_columns > 2:
        examples.append([(0, 1), (1, 2), (0, 2)])
    yield from examples


class TestGeneticRelatednessWeighted(StatsTestCase, WeightStatsMixin):
    """
    Checks that ``ts.genetic_relatedness_weighted`` agrees with the
    equivalent ``general_stat`` computation and with the reference
    ``genetic_relatedness_weighted`` definition in this module.
    """

    # Derived classes define this to get a specific stats mode.
    mode = None

    def verify_definition(
        self, ts, W, indexes, windows, summary_func, ts_method, definition
    ):
        """
        Compute the statistic four ways and require that all results have the
        same shape and are almost equal: via ``ts.general_stat``, via the
        module-level ``general_stat``, via ``ts_method`` and via
        ``definition``.
        """
        # The summary function returns one value per index pair.
        output_dim = len(indexes)

        expected = ts.general_stat(
            W, summary_func, output_dim, windows, mode=self.mode, span_normalise=True
        )
        others = (
            general_stat(
                ts, W, summary_func, windows, mode=self.mode, span_normalise=True
            ),
            ts_method(W, indexes=indexes, windows=windows, mode=self.mode),
            definition(ts, W, indexes=indexes, windows=windows, mode=self.mode),
        )

        # Check every shape first, then the values, so that a shape mismatch
        # is reported before any numerical comparison is attempted.
        for other in others:
            assert expected.shape == other.shape
        for other in others:
            self.assertArrayAlmostEqual(expected, other)

    def verify(self, ts):
        """
        Run ``verify_weighted_stat`` on a subset of the example weights and
        windows for ``ts``, using each example list of index pairs.
        """
        weight_window_combos = subset_combos(
            self.example_weights(ts, min_size=2), example_windows(ts), p=0.1
        )
        for W, windows in weight_window_combos:
            for indexes in example_index_pairs(W):
                self.verify_weighted_stat(ts, W, indexes, windows)

    def verify_weighted_stat(self, ts, W, indexes, windows):
        """
        Centre the columns of ``W``, build the matching summary function and
        hand both to ``verify_definition``.
        """
        W = W - W.mean(axis=0)
        column_totals = W.sum(axis=0)
        num_rows = W.shape[0]

        def f(x):
            mean_x = np.sum(x) / num_rows

            def deviation(k):
                return x[k] - column_totals[k] * mean_x

            return np.array([deviation(i) * deviation(j) / 2 for i, j in indexes])

        self.verify_definition(
            ts,
            W,
            indexes,
            windows,
            f,
            ts.genetic_relatedness_weighted,
            genetic_relatedness_weighted,
        )


class TestBranchGeneticRelatednessWeighted(
    TestGeneticRelatednessWeighted, TopologyExamplesMixin
):
    """Weighted genetic relatedness checks with the stats mode set to "branch"."""

    mode = "branch"


class TestNodeGeneticRelatednessWeighted(
    TestGeneticRelatednessWeighted, TopologyExamplesMixin
):
    """Weighted genetic relatedness checks with the stats mode set to "node"."""

    mode = "node"


class TestSiteGeneticRelatednessWeighted(
    TestGeneticRelatednessWeighted, MutatedTopologyExamplesMixin
):
    """Weighted genetic relatedness checks with the stats mode set to "site"."""

    mode = "site"


############################################
# Fst
############################################
Expand Down
13 changes: 8 additions & 5 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -8000,11 +8000,14 @@ def genetic_relatedness_weighted(
then the k-th column of output will be
:math:`\sum_{a,b} W_{ai} W_{bj} C_{ab}`,
where :math:`W` is the matrix of weights, and :math:`C_{ab}` is the
{meth}`.genetic_relatedness` between sample i and sample j.
:param numpy.ndarray W: An array of values with one row for each sample and one
column for each set of weights.
:param list indexes: A list of 2-tuples, or None.
:meth:`genetic_relatedness <.TreeSequence.genetic_relatedness>` between sample
a and sample b, summing over all pairs of samples in the tree sequence.
:param numpy.ndarray W: An array of values with one row for each sample node and
one column for each set of weights.
:param list indexes: A list of 2-tuples, or None (default). Note that if
indexes = None, then W must have exactly two columns and this is equivalent
to indexes = [(0,1)].
:param list windows: An increasing list of breakpoints between the windows
to compute the statistic in.
:param str mode: A string giving the "type" of the statistic to be computed
Expand Down

0 comments on commit 6c144d8

Please sign in to comment.