Skip to content

Commit

Permalink
add centre option
Browse files Browse the repository at this point in the history
  • Loading branch information
petrelharp committed Sep 19, 2024
1 parent 39c4c14 commit 86b87d7
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 41 deletions.
11 changes: 8 additions & 3 deletions python/_tskitmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -9651,7 +9651,8 @@ TreeSequence_weighted_stat_vector_method(
TreeSequence *self, PyObject *args, PyObject *kwds, weighted_vector_method *method)
{
PyObject *ret = NULL;
static char *kwlist[] = { "weights", "windows", "mode", "span_normalise", NULL };
static char *kwlist[]
= { "weights", "windows", "mode", "span_normalise", "centre", NULL };
PyObject *weights = NULL;
PyObject *windows = NULL;
PyArrayObject *weights_array = NULL;
Expand All @@ -9663,13 +9664,14 @@ TreeSequence_weighted_stat_vector_method(
tsk_size_t num_samples = tsk_treeseq_get_num_samples(self->tree_sequence);
char *mode = NULL;
int span_normalise = true;
int centre = true;
int err;

if (TreeSequence_check_state(self) != 0) {
goto out;
}
if (!PyArg_ParseTupleAndKeywords(
args, kwds, "OO|si", kwlist, &weights, &windows, &mode, &span_normalise)) {
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO|sii", kwlist, &weights, &windows,
&mode, &span_normalise, &centre)) {
goto out;
}
if (parse_stats_mode(mode, &options) != 0) {
Expand All @@ -9678,6 +9680,9 @@ TreeSequence_weighted_stat_vector_method(
if (span_normalise) {
options |= TSK_STAT_SPAN_NORMALISE;
}
if (!centre) {
options |= TSK_STAT_NONCENTRED;
}
if (parse_windows(windows, &windows_array, &num_windows) != 0) {
goto out;
}
Expand Down
97 changes: 62 additions & 35 deletions python/tests/test_relatedness_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@ def __init__(
sequence_length,
verbosity=0,
internal_checks=False,
centre=True,
):
self.sample_weights = np.asarray(sample_weights, dtype=np.float64)
num_weights = self.sample_weights.shape[1]
# virtual root is at num_nodes; virtual samples are beyond that
N = num_nodes + 1 + len(samples)
self.parent = np.full(N, -1, dtype=np.int32)
Expand All @@ -72,10 +74,14 @@ def __init__(
self.position = 0
self.virtual_root = num_nodes
self.x = np.zeros(N, dtype=np.float64)
self.w = np.zeros(N, dtype=np.float64)
self.v = np.zeros(N, dtype=np.float64)
self.w = np.zeros((N, num_weights), dtype=np.float64)
self.v = np.zeros((N, num_weights), dtype=np.float64)
self.verbosity = verbosity
self.internal_checks = internal_checks
self.centre = centre

if self.centre:
self.sample_weights -= np.mean(self.sample_weights, axis=0)

for j, u in enumerate(samples):
self.w[u] = self.sample_weights[j]
Expand Down Expand Up @@ -254,10 +260,12 @@ def run(self):
if self.verbosity > 1:
self.print_state()

out = np.zeros(len(self.samples))
out = np.zeros(self.sample_weights.shape)
for out_i in range(len(self.samples)):
i = out_i + self.virtual_root + 1
out[out_i] = self.v[i]
if self.centre:
out -= np.mean(out, axis=0)
return out


Expand All @@ -279,53 +287,64 @@ def relatedness_vector(ts, sample_weights, **kwargs):
return rv.run()


def relatedness_matrix(ts, centre):
    """
    Return the branch-mode genetic relatedness between every pair of samples.

    Each sample is placed in its own single-element sample set and
    ``ts.genetic_relatedness`` is evaluated for all ordered pairs of those
    sets with ``mode="branch"``, ``span_normalise=False`` and
    ``proportion=False``.

    :param ts: Tree sequence to compute relatedness on.
    :param bool centre: Forwarded unchanged as the ``centre`` argument of
        ``ts.genetic_relatedness``.
    :return: The result reshaped to ``(ts.num_samples, ts.num_samples)``, so
        that entry ``[i, j]`` is the statistic for the sample-set pair
        ``(i, j)``.
    """
    n = ts.num_samples
    Sigma = ts.genetic_relatedness(
        sample_sets=[[u] for u in ts.samples()],
        # Pairs are listed in row-major order so the reshape below places
        # the value for pair (i, j) at position [i, j].
        indexes=[(i, j) for i in range(n) for j in range(n)],
        mode="branch",
        span_normalise=False,
        proportion=False,
        centre=centre,
    )
    return Sigma.reshape((n, n))


def verify_relatedness_vector(
    ts, w, *, internal_checks=False, verbosity=0, centre=True
):
    """
    Check ``relatedness_vector`` against an explicit matrix-vector product.

    The result of ``relatedness_vector(ts, sample_weights=w, ...)`` is
    compared with ``relatedness_matrix(ts, centre=centre).dot(w)`` using
    ``np.testing.assert_allclose`` with ``atol=1e-14``.

    :param ts: Tree sequence under test.
    :param w: Weights, passed as ``sample_weights`` and used as the right-hand
        operand of the matrix product.
    :param bool internal_checks: Forwarded to ``relatedness_vector``.
    :param int verbosity: Forwarded to ``relatedness_vector``; when greater
        than zero the tree sequence, weights, both results and the matrix are
        also printed here.
    :param bool centre: Forwarded to both ``relatedness_vector`` and
        ``relatedness_matrix``.
    :return: The value computed by ``relatedness_vector``.
    :raises AssertionError: If the two results disagree.
    """
    R1 = relatedness_vector(
        ts,
        sample_weights=w,
        internal_checks=internal_checks,
        verbosity=verbosity,
        centre=centre,
    )
    Sigma = relatedness_matrix(ts, centre=centre)
    R2 = Sigma.dot(w)
    # TODO: the comparison against the library implementation is currently
    # disabled; restore these two lines once it is expected to agree.
    # R3 = ts.genetic_relatedness_vector(w, mode="branch", centre=centre)
    # np.testing.assert_allclose(R1, R3)
    if verbosity > 0:
        print(ts.draw_text())
        print("weights:", w)
        print("here:", R1)
        print("with ts:", R2)
        print("Sigma:", Sigma)
    np.testing.assert_allclose(R1, R2, atol=1e-14)
    return R1


def check_relatedness_vector(
    ts, n=5, *, internal_checks=False, verbosity=0, seed=123, centre=True
):
    """
    Run ``verify_relatedness_vector`` on ``n`` random integer weight matrices.

    For ``k = 0, 1, ..., n - 1`` a ``(ts.num_samples, k)`` matrix of normal
    draws is generated, scaled by the number of rows and rounded, then passed
    to ``verify_relatedness_vector``. Note that ``k`` starts at zero, so the
    first matrix has no columns.

    :param ts: Tree sequence under test.
    :param int n: Number of weight matrices to try.
    :param bool internal_checks: Forwarded to ``verify_relatedness_vector``.
    :param int verbosity: Forwarded to ``verify_relatedness_vector``.
    :param int seed: Seed for ``np.random.default_rng``.
    :param bool centre: Forwarded to ``verify_relatedness_vector``.
    :return: The value returned by the last verification, or ``None`` when
        ``n <= 0`` and no verification was run.
    """
    rng = np.random.default_rng(seed=seed)
    num_samples = ts.num_samples
    # Initialised so that n <= 0 returns None rather than hitting an
    # unbound local at the return statement.
    R = None
    for k in range(n):
        w = rng.normal(size=num_samples * k).reshape((num_samples, k))
        # len(w) is the number of rows; scaling before rounding spreads the
        # draws over a wider range of integer values.
        w = np.round(len(w) * w)
        R = verify_relatedness_vector(
            ts,
            w,
            internal_checks=internal_checks,
            verbosity=verbosity,
            centre=centre,
        )
    return R


class TestExamples:
@pytest.mark.parametrize("n", [2, 3, 5])
@pytest.mark.parametrize("seed", range(1, 4))
def test_small_internal_checks(self, n, seed):
@pytest.mark.parametrize("centre", (True, False))
def test_small_internal_checks(self, n, seed, centre):
ts = msprime.sim_ancestry(
n,
ploidy=1,
Expand All @@ -334,11 +353,12 @@ def test_small_internal_checks(self, n, seed):
random_seed=seed,
)
assert ts.num_trees >= 2
check_relatedness_vector(ts, internal_checks=True)
check_relatedness_vector(ts, internal_checks=True, centre=centre)

@pytest.mark.parametrize("n", [2, 3, 5, 15])
@pytest.mark.parametrize("seed", range(1, 5))
def test_simple_sims(self, n, seed):
@pytest.mark.parametrize("centre", (True, False))
def test_simple_sims(self, n, seed, centre):
ts = msprime.sim_ancestry(
n,
ploidy=1,
Expand All @@ -348,25 +368,28 @@ def test_simple_sims(self, n, seed):
random_seed=seed,
)
assert ts.num_trees >= 2
check_relatedness_vector(ts)
check_relatedness_vector(ts, centre=centre)

@pytest.mark.parametrize("n", [2, 3, 5, 15])
def test_single_balanced_tree(self, n):
@pytest.mark.parametrize("centre", (True, False))
def test_single_balanced_tree(self, n, centre):
ts = tskit.Tree.generate_balanced(n).tree_sequence
check_relatedness_vector(ts, internal_checks=True, verbosity=1)
check_relatedness_vector(ts, internal_checks=True, centre=centre)

def test_internal_sample(self):
@pytest.mark.parametrize("centre", (True, False))
def test_internal_sample(self, centre):
tables = tskit.Tree.generate_balanced(4).tree_sequence.dump_tables()
flags = tables.nodes.flags
flags[3] = 0
flags[5] = tskit.NODE_IS_SAMPLE
tables.nodes.flags = flags
ts = tables.tree_sequence()
check_relatedness_vector(ts, verbosity=0)
check_relatedness_vector(ts, centre=centre)

# @pytest.mark.skip()
@pytest.mark.parametrize("seed", range(1, 5))
def test_one_internal_sample_sims(self, seed):
@pytest.mark.parametrize("centre", (True, False))
def test_one_internal_sample_sims(self, seed, centre):
ts = msprime.sim_ancestry(
10,
ploidy=1,
Expand All @@ -382,9 +405,10 @@ def test_one_internal_sample_sims(self, seed):
t.sort()
t.build_index()
ts = t.tree_sequence()
check_relatedness_vector(ts)
check_relatedness_vector(ts, centre=centre)

def test_missing_flanks(self):
@pytest.mark.parametrize("centre", (True, False))
def test_missing_flanks(self, centre):
ts = msprime.sim_ancestry(
2,
ploidy=1,
Expand All @@ -396,12 +420,13 @@ def test_missing_flanks(self):
assert ts.num_trees >= 2
ts = ts.keep_intervals([[20, 80]])
assert ts.first().interval == (0, 20)
check_relatedness_vector(ts, verbosity=2)
check_relatedness_vector(ts, centre=centre)

@pytest.mark.parametrize("ts", get_example_tree_sequences())
def test_suite_examples(self, ts):
@pytest.mark.parametrize("centre", (True, False))
def test_suite_examples(self, ts, centre):
if ts.num_samples > 0:
check_relatedness_vector(ts)
check_relatedness_vector(ts, centre=centre)

@pytest.mark.parametrize("n", [2, 3, 10])
def test_dangling_on_samples(self, n):
Expand All @@ -420,26 +445,28 @@ def test_dangling_on_samples(self, n):
np.testing.assert_array_almost_equal(D1, D2)

@pytest.mark.parametrize("n", [2, 3, 10])
def test_dangling_on_all(self, n):
@pytest.mark.parametrize("centre", (True, False))
def test_dangling_on_all(self, n, centre):
# Adding non sample branches below the samples does not alter
# the overall divergence *between* the samples
ts1 = tskit.Tree.generate_balanced(n).tree_sequence
D1 = check_relatedness_vector(ts1)
D1 = check_relatedness_vector(ts1, centre=centre)
tables = ts1.dump_tables()
for u in range(ts1.num_nodes):
v = tables.nodes.add_row(time=-1)
tables.edges.add_row(left=0, right=ts1.sequence_length, parent=u, child=v)
tables.sort()
tables.build_index()
ts2 = tables.tree_sequence()
D2 = check_relatedness_vector(ts2, internal_checks=True)
D2 = check_relatedness_vector(ts2, internal_checks=True, centre=centre)
np.testing.assert_array_almost_equal(D1, D2)

def test_disconnected_non_sample_topology(self):
@pytest.mark.parametrize("centre", (True, False))
def test_disconnected_non_sample_topology(self, centre):
# Adding non sample branches below the samples does not alter
# the overall divergence *between* the samples
ts1 = tskit.Tree.generate_balanced(5).tree_sequence
D1 = check_relatedness_vector(ts1)
D1 = check_relatedness_vector(ts1, centre=centre)
tables = ts1.dump_tables()
# Add an extra bit of disconnected non-sample topology
u = tables.nodes.add_row(time=0)
Expand All @@ -448,5 +475,5 @@ def test_disconnected_non_sample_topology(self):
tables.sort()
tables.build_index()
ts2 = tables.tree_sequence()
D2 = check_relatedness_vector(ts2, internal_checks=True)
D2 = check_relatedness_vector(ts2, internal_checks=True, centre=centre)
np.testing.assert_array_almost_equal(D1, D2)
12 changes: 9 additions & 3 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -7771,6 +7771,7 @@ def __weighted_vector_stat(
windows=None,
mode=None,
span_normalise=True,
centre=True,
):
W = np.asarray(W)
if len(W.shape) == 1:
Expand All @@ -7781,6 +7782,7 @@ def __weighted_vector_stat(
W,
mode=mode,
span_normalise=span_normalise,
centre=centre,
)
return stat

Expand Down Expand Up @@ -8429,7 +8431,7 @@ def genetic_relatedness_vector(
windows=None,
mode="site",
span_normalise=True,
polarised=True,
centre=True,
):
r"""
Computes the product of the genetic relatedness matrix and a vector of weights
Expand All @@ -8439,6 +8441,9 @@ def genetic_relatedness_vector(
:meth:`genetic_relatedness <.TreeSequence.genetic_relatedness>` between sample
a and sample b, and the sum is over all samples in the tree sequence.
The relatedness used here corresponds to `polarised=True`; no unpolarised option
is available for this method.
:param numpy.ndarray W: An array of values with one row for each sample node and
one column for each set of weights.
:param list windows: An increasing list of breakpoints between the windows
Expand All @@ -8447,20 +8452,21 @@ def genetic_relatedness_vector(
(defaults to "site").
:param bool span_normalise: Whether to divide the result by the span of the
window (defaults to True).
:param bool centre: Whether to use the *centred* relatedness matrix or not:
see :meth:`genetic_relatedness <.TreeSequence.genetic_relatedness>`.
:return: A ndarray with shape equal to (num windows, num weights).
"""
if len(W) != self.num_samples:
raise ValueError(
"First trait dimension must be equal to number of samples."
)
if not polarised:
raise ValueError("genetic_relatedness_vector is not available unpolarised.")
return self.__weighted_vector_stat(
self._ll_tree_sequence.genetic_relatedness_vector,
W,
windows=windows,
mode=mode,
span_normalise=span_normalise,
centre=centre,
)

def trait_covariance(self, W, windows=None, mode="site", span_normalise=True):
Expand Down

0 comments on commit 86b87d7

Please sign in to comment.