Skip to content

Commit

Permalink
Fix python version of alg
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Aug 6, 2023
1 parent d9810bd commit 2271b33
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 16 deletions.
3 changes: 3 additions & 0 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -6630,6 +6630,9 @@ update_site_divergence(
tsk_id_t a;
tsk_size_t j, num_A, num_B;

/* This algorithm is incorrect - we are currently double counting. See the
* python version for the right way to do it.
*/
for (a = 0; a < (tsk_id_t) variant->num_alleles - 1; a++) {
num_A = 0;
num_B = 0;
Expand Down
52 changes: 36 additions & 16 deletions python/tests/test_divmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,12 +264,6 @@ def stats_api_divergence_matrix(ts, windows=None, samples=None, mode="site"):
return out


def rootward_path(tree, u, v):
while u != v:
yield u
u = tree.parent(u)


def site_divergence_matrix(ts, windows=None, samples=None):
windows_specified = windows is not None
windows = [0, ts.sequence_length] if windows is None else windows
Expand All @@ -295,14 +289,19 @@ def site_divergence_matrix(ts, windows=None, samples=None):
assert ts.sites_position[site_id] >= left
while site_id < ts.num_sites and ts.sites_position[site_id] < right:
variant.decode(site_id)
max_allele = np.max(variant.genotypes)
for a in range(max_allele):
A = np.where(variant.genotypes == a)[0]
B = np.where(variant.genotypes != a)[0]
for j in A:
for k in B:
D[i, j, k] += 1
D[i, k, j] += 1
# This could be implemented in fixed memory by num_alleles - 1 passes
# through the genotypes array
allele_samples = [[] for _ in range(variant.num_alleles)]
for j, a in enumerate(variant.genotypes):
allele_samples[a].append(j)
for j in range(variant.num_alleles):
A = allele_samples[j]
for k in range(j + 1, variant.num_alleles):
B = allele_samples[k]
for a in A:
for b in B:
D[i, a, b] += 1
D[i, b, a] += 1
site_id += 1
if not windows_specified:
D = D[0]
Expand Down Expand Up @@ -332,8 +331,8 @@ def check_divmat(
ts, windows=windows, samples=samples, mode=mode
)
# print("windows = ", windows)
# print(D1)
# print(D2)
print(D1)
print(D2)
np.testing.assert_allclose(D1, D2)
assert D1.shape == D2.shape
if compare_lib:
Expand Down Expand Up @@ -380,6 +379,8 @@ def test_single_tree_sites_per_branch(self, m):
)
np.testing.assert_array_equal(D1, m * D2)

# Disabling while we're experimenting with variant-based definition
@pytest.mark.skip()
@pytest.mark.parametrize("m", [0, 1, 2, 10])
def test_single_tree_mutations_per_branch(self, m):
# 2.00┊ 6 ┊
Expand All @@ -404,6 +405,24 @@ def test_single_tree_mutations_per_branch(self, m):
)
np.testing.assert_array_equal(D1, m * D2)

@pytest.mark.parametrize("n", [2, 3, 5])
def test_single_tree_unique_sample_alleles(self, n):
tables = tskit.Tree.generate_balanced(n).tree_sequence.dump_tables()
tables.sites.add_row(position=0.5, ancestral_state="0")
for j in range(n):
tables.mutations.add_row(site=0, node=j, derived_state=f"{j + 1}")
ts = tables.tree_sequence()
D1 = check_divmat(ts, mode="site")
# D2 = np.array(
# [
# [0.0, 2.0, 4.0, 4.0],
# [2.0, 0.0, 4.0, 4.0],
# [4.0, 4.0, 0.0, 2.0],
# [4.0, 4.0, 2.0, 0.0],
# ]
# )
# np.testing.assert_array_equal(D1, m * D2)

@pytest.mark.parametrize("L", [0.1, 1, 2, 100])
def test_single_tree_sequence_length(self, L):
# 2.00┊ 6 ┊
Expand Down Expand Up @@ -803,6 +822,7 @@ def check(self, ts, windows=None, samples=None, num_threads=0, mode="branch"):
else:
assert mode == "site"
# if np.any(ts.mutations_parent != tskit.NULL):
# print("HERE")
# # The stats API computes something slightly different when we have
# # recurrent mutations, so fall back to the naive version.
# D2 = site_divergence_matrix(ts, windows=windows, samples=samples)
Expand Down

0 comments on commit 2271b33

Please sign in to comment.