NOTE(review): line breaks and leading indentation in this patch were restored
from a whitespace-collapsed copy. Line counts agree with every @@ header; the
indentation of context lines is inferred and should be confirmed against
c/tskit/trees.c and python/tests/test_tree_stats.py before applying.

diff --git a/c/tskit/trees.c b/c/tskit/trees.c
index f51a4a0333..0ec70c7687 100644
--- a/c/tskit/trees.c
+++ b/c/tskit/trees.c
@@ -2834,9 +2834,6 @@ tsk_treeseq_update_branch_afs(const tsk_treeseq_t *self, tsk_id_t u, double righ
     }
 
     if (0 < all_samples && all_samples < self->num_samples) {
-        if (!polarised) {
-            x *= 0.5;
-        }
         afs_size = result_dims[num_sample_sets];
         afs = result + afs_size * window_index;
         for (k = 0; k < num_sample_sets; k++) {
diff --git a/python/tests/test_tree_stats.py b/python/tests/test_tree_stats.py
index 99e8e11c55..a272c4f8cf 100644
--- a/python/tests/test_tree_stats.py
+++ b/python/tests/test_tree_stats.py
@@ -1,6 +1,6 @@
 # MIT License
 #
-# Copyright (c) 2018-2023 Tskit Developers
+# Copyright (c) 2018-2024 Tskit Developers
 # Copyright (C) 2016 University of Oxford
 #
 # Permission is hereby granted, free of charge, to any person obtaining a copy
@@ -3395,6 +3395,13 @@ def foldit(A):
     return B
 
 
+def fold_windowed(X):
+    Y = np.zeros(X.shape)
+    for k in range(X.shape[0]):
+        Y[k] = foldit(X[k])
+    return Y
+
+
 class TestFold:
     """
     Tests for the fold operation used in the AFS.
@@ -3432,6 +3439,29 @@ def test_examples(self):
         Ef = np.array([8.0, 8.0, 8.0, 8.0, 4.0, 0.0, 0.0, 0.0, 0.0])
         assert np.all(foldit(E) == Ef)
 
+    def test_branch_folded(self):
+        ts = msprime.sim_ancestry(10, random_seed=1, sequence_length=10)
+        folded = ts.allele_frequency_spectrum(
+            windows=[0, 5, 8, 9, 10], mode="branch", polarised=False
+        )
+        unfolded = ts.allele_frequency_spectrum(
+            windows=[0, 5, 8, 9, 10], mode="branch", polarised=True
+        )
+        assert np.allclose(fold_windowed(unfolded), folded)
+
+    def test_site_folded(self):
+        ts = msprime.sim_ancestry(10, random_seed=1, sequence_length=10)
+        ts = msprime.sim_mutations(ts, rate=1, random_seed=1, discrete_genome=False)
+        for s in ts.sites():
+            assert len(s.mutations) == 1
+        folded = ts.allele_frequency_spectrum(
+            windows=[0, 5, 8, 9, 10], mode="site", polarised=False, span_normalise=False
+        )
+        unfolded = ts.allele_frequency_spectrum(
+            windows=[0, 5, 8, 9, 10], mode="site", polarised=True, span_normalise=False
+        )
+        assert np.allclose(fold_windowed(unfolded), folded)
+
 
 def naive_site_allele_frequency_spectrum(
     ts, sample_sets, windows=None, polarised=False, span_normalise=True
@@ -3513,11 +3543,9 @@ def naive_branch_allele_frequency_spectrum(
                     if 0 < t.num_samples(node) < ts.num_samples:
                         x = [tree.num_tracked_samples(node) for tree in trees]
                         # Note x must be a tuple for indexing to work
-                        if polarised:
-                            S[tuple(x)] += t.branch_length(node) * tr_len
-                        else:
+                        if not polarised:
                             x = fold(x, out_dim)
-                            S[tuple(x)] += 0.5 * t.branch_length(node) * tr_len
+                        S[tuple(x)] += t.branch_length(node) * tr_len
 
             # Advance the trees
             more = [tree.next() for tree in trees]
@@ -3582,7 +3610,6 @@ def update_result(window_index, u, right):
             c = count[u, :num_sample_sets]
             if not polarised:
                 c = fold(c, out_dim)
-                x *= 0.5
             index = tuple([window_index] + list(c))
             result[index] += x
         last_update[u] = right
@@ -3893,8 +3920,6 @@ def verify(self, ts):
                 polarised=polarised,
                 span_normalise=True,
             )
-            if not polarised:
-                afs *= 2
             afs_sum = [np.sum(window) for window in afs]
             self.assertArrayAlmostEqual(afs_sum, tbl)
 