Skip to content

Commit

Permalink
remove 1/2 in branch AFS: closes #2925
Browse files Browse the repository at this point in the history
  • Loading branch information
petrelharp authored and benjeffery committed Jun 17, 2024
1 parent e94fd21 commit f7d6bff
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 11 deletions.
3 changes: 0 additions & 3 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -2834,9 +2834,6 @@ tsk_treeseq_update_branch_afs(const tsk_treeseq_t *self, tsk_id_t u, double righ
}

if (0 < all_samples && all_samples < self->num_samples) {
if (!polarised) {
x *= 0.5;
}
afs_size = result_dims[num_sample_sets];
afs = result + afs_size * window_index;
for (k = 0; k < num_sample_sets; k++) {
Expand Down
41 changes: 33 additions & 8 deletions python/tests/test_tree_stats.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# MIT License
#
# Copyright (c) 2018-2023 Tskit Developers
# Copyright (c) 2018-2024 Tskit Developers
# Copyright (C) 2016 University of Oxford
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
Expand Down Expand Up @@ -3395,6 +3395,13 @@ def foldit(A):
return B


def fold_windowed(X):
Y = np.zeros(X.shape)
for k in range(X.shape[0]):
Y[k] = foldit(X[k])
return Y


class TestFold:
"""
Tests for the fold operation used in the AFS.
Expand Down Expand Up @@ -3432,6 +3439,29 @@ def test_examples(self):
Ef = np.array([8.0, 8.0, 8.0, 8.0, 4.0, 0.0, 0.0, 0.0, 0.0])
assert np.all(foldit(E) == Ef)

def test_branch_folded(self):
ts = msprime.sim_ancestry(10, random_seed=1, sequence_length=10)
folded = ts.allele_frequency_spectrum(
windows=[0, 5, 8, 9, 10], mode="branch", polarised=False
)
unfolded = ts.allele_frequency_spectrum(
windows=[0, 5, 8, 9, 10], mode="branch", polarised=True
)
assert np.allclose(fold_windowed(unfolded), folded)

def test_site_folded(self):
ts = msprime.sim_ancestry(10, random_seed=1, sequence_length=10)
ts = msprime.sim_mutations(ts, rate=1, random_seed=1, discrete_genome=False)
for s in ts.sites():
assert len(s.mutations) == 1
folded = ts.allele_frequency_spectrum(
windows=[0, 5, 8, 9, 10], mode="site", polarised=False, span_normalise=False
)
unfolded = ts.allele_frequency_spectrum(
windows=[0, 5, 8, 9, 10], mode="site", polarised=True, span_normalise=False
)
assert np.allclose(fold_windowed(unfolded), folded)


def naive_site_allele_frequency_spectrum(
ts, sample_sets, windows=None, polarised=False, span_normalise=True
Expand Down Expand Up @@ -3513,11 +3543,9 @@ def naive_branch_allele_frequency_spectrum(
if 0 < t.num_samples(node) < ts.num_samples:
x = [tree.num_tracked_samples(node) for tree in trees]
# Note x must be a tuple for indexing to work
if polarised:
S[tuple(x)] += t.branch_length(node) * tr_len
else:
if not polarised:
x = fold(x, out_dim)
S[tuple(x)] += 0.5 * t.branch_length(node) * tr_len
S[tuple(x)] += t.branch_length(node) * tr_len

# Advance the trees
more = [tree.next() for tree in trees]
Expand Down Expand Up @@ -3582,7 +3610,6 @@ def update_result(window_index, u, right):
c = count[u, :num_sample_sets]
if not polarised:
c = fold(c, out_dim)
x *= 0.5
index = tuple([window_index] + list(c))
result[index] += x
last_update[u] = right
Expand Down Expand Up @@ -3893,8 +3920,6 @@ def verify(self, ts):
polarised=polarised,
span_normalise=True,
)
if not polarised:
afs *= 2
afs_sum = [np.sum(window) for window in afs]
self.assertArrayAlmostEqual(afs_sum, tbl)

Expand Down

0 comments on commit f7d6bff

Please sign in to comment.