Skip to content

Commit

Permalink
Experimental variant based method
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Aug 5, 2023
1 parent 35fef5c commit d9810bd
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 109 deletions.
129 changes: 48 additions & 81 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -6620,122 +6620,89 @@ increment_divergence_matrix_pairs(const tsk_size_t len_A, const tsk_id_t *restri
}
}

static void
update_site_divergence(const tsk_tree_t *tree, tsk_id_t node,
const tsk_id_t *sample_index_map, tsk_size_t num_samples, tsk_id_t *restrict stack,
int8_t *restrict descending_bitset, tsk_id_t *restrict descending_list,
tsk_id_t *restrict not_descending_list, double *D)
{
const tsk_id_t *restrict left_child = tree->left_child;
const tsk_id_t *restrict right_sib = tree->right_sib;
int stack_top;
tsk_id_t a, u, v;
tsk_size_t j, num_descending, num_not_descending;

tsk_memset(descending_bitset, 0, num_samples * sizeof(*descending_bitset));
#include <tskit/genotypes.h>

stack_top = 0;
stack[stack_top] = node;
while (stack_top >= 0) {
u = stack[stack_top];
stack_top--;
a = sample_index_map[u];
if (a != TSK_NULL) {
descending_bitset[a] = 1;
}
for (v = left_child[u]; v != TSK_NULL; v = right_sib[v]) {
stack_top++;
stack[stack_top] = v;
}
}

num_descending = 0;
num_not_descending = 0;
for (j = 0; j < num_samples; j++) {
if (descending_bitset[j]) {
descending_list[num_descending] = (tsk_id_t) j;
num_descending++;
} else {
not_descending_list[num_not_descending] = (tsk_id_t) j;
num_not_descending++;
static void
update_site_divergence(
const tsk_variant_t *variant, tsk_id_t *restrict A, tsk_id_t *restrict B, double *D)
{
const tsk_id_t *restrict genotypes = variant->genotypes;
tsk_id_t a;
tsk_size_t j, num_A, num_B;

for (a = 0; a < (tsk_id_t) variant->num_alleles - 1; a++) {
num_A = 0;
num_B = 0;
for (j = 0; j < variant->num_samples; j++) {
if (genotypes[j] == a) {
A[num_A] = (tsk_id_t) j;
num_A++;
} else {
B[num_B] = (tsk_id_t) j;
num_B++;
}
}
tsk_bug_assert(num_A + num_B == variant->num_samples);
increment_divergence_matrix_pairs(num_A, A, num_B, B, D);
}
tsk_bug_assert(num_descending + num_not_descending == num_samples);

increment_divergence_matrix_pairs(
num_descending, descending_list, num_not_descending, not_descending_list, D);
}

static int
tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_samples,
tsk_size_t num_windows, const double *restrict windows,
tsk_flags_t TSK_UNUSED(options), const tsk_id_t *restrict sample_index_map,
const tsk_id_t *restrict samples, tsk_size_t num_windows,
const double *restrict windows, tsk_flags_t TSK_UNUSED(options),
double *restrict result)
{
int ret = 0;
tsk_tree_t tree;
tsk_size_t i, tree_site, tree_mut;
tsk_site_t site;
tsk_mutation_t mut;
double left, right, span_left, span_right;
tsk_size_t i;
tsk_id_t site_id;
double left, right;
double *restrict D;
const tsk_size_t num_nodes = self->tables->nodes.num_rows;
int8_t *descending_bitset = tsk_malloc(num_samples * sizeof(*descending_bitset));
const tsk_id_t num_sites = (tsk_id_t) self->tables->sites.num_rows;
const double *restrict sites_position = self->tables->sites.position;
tsk_id_t *descending_list = tsk_malloc(num_samples * sizeof(*descending_list));
tsk_id_t *not_descending_list
= tsk_malloc(num_samples * sizeof(*not_descending_list));
/* Do *not* use tsk_tree_get_size bound here because it gives a per-tree
* bound, not a global one! */
tsk_id_t *stack = tsk_malloc(num_nodes * sizeof(*stack));
tsk_variant_t variant;

ret = tsk_tree_init(&tree, self, 0);
ret = tsk_variant_init(
&variant, self, samples, num_samples, NULL, TSK_ISOLATED_NOT_MISSING);
if (ret != 0) {
goto out;
}

if (descending_bitset == NULL || descending_list == NULL
|| not_descending_list == NULL || stack == NULL) {
if (descending_list == NULL || not_descending_list == NULL) {
ret = TSK_ERR_NO_MEMORY;
goto out;
}

site_id = 0;
while (site_id < num_sites && sites_position[site_id] < windows[0]) {
site_id++;
}

for (i = 0; i < num_windows; i++) {
left = windows[i];
right = windows[i + 1];
D = result + i * num_samples * num_samples;
ret = tsk_tree_seek(&tree, left, 0);
if (ret != 0) {
goto out;
}
while (tree.interval.left < right && tree.index != -1) {
span_left = TSK_MAX(tree.interval.left, left);
span_right = TSK_MIN(tree.interval.right, right);

for (tree_site = 0; tree_site < tree.sites_length; tree_site++) {
site = tree.sites[tree_site];
if (span_left <= site.position && site.position < span_right) {
for (tree_mut = 0; tree_mut < site.mutations_length; tree_mut++) {
mut = site.mutations[tree_mut];
update_site_divergence(&tree, mut.node, sample_index_map,
num_samples, stack, descending_bitset, descending_list,
not_descending_list, D);
}
}
}

ret = tsk_tree_next(&tree);
if (ret < 0) {
if (site_id < num_sites) {
tsk_bug_assert(sites_position[site_id] >= left);
}
while (site_id < num_sites && sites_position[site_id] < right) {
ret = tsk_variant_decode(&variant, site_id, 0);
if (ret != 0) {
goto out;
}
update_site_divergence(&variant, descending_list, not_descending_list, D);
site_id++;
}
}
ret = 0;
out:
tsk_tree_free(&tree);
tsk_safe_free(descending_bitset);
tsk_variant_free(&variant);
tsk_safe_free(descending_list);
tsk_safe_free(not_descending_list);
tsk_safe_free(stack);
return ret;
}

Expand Down Expand Up @@ -6859,7 +6826,7 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples,
} else {
tsk_bug_assert(stat_site);
ret = tsk_treeseq_divergence_matrix_site(
self, n, num_windows, windows, options, sample_index_map, result);
self, n, samples, num_windows, windows, options, result);
}
if (ret != 0) {
goto out;
Expand Down
54 changes: 26 additions & 28 deletions python/tests/test_divmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,31 +281,29 @@ def site_divergence_matrix(ts, windows=None, samples=None):
sample_index_map[samples] = np.arange(n)
is_descendant = np.zeros(n, dtype=bool)
D = np.zeros((num_windows, n, n))
tree = tskit.Tree(ts)
site_id = 0
while site_id < ts.num_sites and ts.sites_position[site_id] < windows[0]:
site_id += 1

# Note we have to use isolated_as_missing here because we're working with
# non-sample nodes. There are tricky problems here later with missing data.
variant = tskit.Variant(ts, samples=samples, isolated_as_missing=False)
for i in range(num_windows):
left = windows[i]
right = windows[i + 1]
tree.seek(left)
# Iterate over the trees in this window
while tree.interval.left < right and tree.index != -1:
span_left = max(tree.interval.left, left)
span_right = min(tree.interval.right, right)
for site in tree.sites():
if span_left <= site.position < span_right:
for mutation in site.mutations:
descendants = []
for u in tree.nodes(mutation.node):
if sample_index_map[u] != -1:
is_descendant[sample_index_map[u]] = True

descendants = np.where(is_descendant)[0]
not_descendants = np.where(np.logical_not(is_descendant))[0]
for j in descendants:
for k in not_descendants:
D[i, j, k] += 1
D[i, k, j] += 1
is_descendant[:] = False
tree.next()
if site_id < ts.num_sites:
assert ts.sites_position[site_id] >= left
while site_id < ts.num_sites and ts.sites_position[site_id] < right:
variant.decode(site_id)
max_allele = np.max(variant.genotypes)
for a in range(max_allele):
A = np.where(variant.genotypes == a)[0]
B = np.where(variant.genotypes != a)[0]
for j in A:
for k in B:
D[i, j, k] += 1
D[i, k, j] += 1
site_id += 1
if not windows_specified:
D = D[0]
return D
Expand Down Expand Up @@ -350,15 +348,15 @@ class TestExamplesWithAnswer:
@pytest.mark.parametrize("mode", DIVMAT_MODES)
def test_single_tree_zero_samples(self, mode):
ts = tskit.Tree.generate_balanced(2).tree_sequence
D = check_divmat(ts, samples=[], mode="site")
D = check_divmat(ts, samples=[], mode=mode)
assert D.shape == (0, 0)

@pytest.mark.parametrize("num_windows", [1, 2, 3, 5])
@pytest.mark.parametrize("mode", DIVMAT_MODES)
def test_single_tree_zero_samples_windows(self, num_windows, mode):
ts = tskit.Tree.generate_balanced(2).tree_sequence
windows = np.linspace(0, ts.sequence_length, num=num_windows + 1)
D = check_divmat(ts, samples=[], windows=windows, mode="site")
D = check_divmat(ts, samples=[], windows=windows, mode=mode)
assert D.shape == (num_windows, 0, 0)

@pytest.mark.parametrize("m", [0, 1, 2, 10])
Expand Down Expand Up @@ -804,10 +802,10 @@ def check(self, ts, windows=None, samples=None, num_threads=0, mode="branch"):
np.testing.assert_allclose(D1, D2, atol=atol)
else:
assert mode == "site"
if np.any(ts.mutations_parent != tskit.NULL):
# The stats API computes something slightly different when we have
# recurrent mutations, so fall back to the naive version.
D2 = site_divergence_matrix(ts, windows=windows, samples=samples)
# if np.any(ts.mutations_parent != tskit.NULL):
# # The stats API computes something slightly different when we have
# # recurrent mutations, so fall back to the naive version.
# D2 = site_divergence_matrix(ts, windows=windows, samples=samples)
np.testing.assert_array_equal(D1, D2)

@pytest.mark.parametrize("ts", get_example_tree_sequences())
Expand Down

0 comments on commit d9810bd

Please sign in to comment.