From d9810bd46be3a202d59e152488c3e83210b6e6f1 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Sat, 5 Aug 2023 11:09:53 +0100 Subject: [PATCH] Experimental variant based method --- c/tskit/trees.c | 129 ++++++++++++++---------------------- python/tests/test_divmat.py | 54 ++++++++------- 2 files changed, 74 insertions(+), 109 deletions(-) diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 6a914582aa..05976514bb 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -6620,122 +6620,89 @@ increment_divergence_matrix_pairs(const tsk_size_t len_A, const tsk_id_t *restri } } -static void -update_site_divergence(const tsk_tree_t *tree, tsk_id_t node, - const tsk_id_t *sample_index_map, tsk_size_t num_samples, tsk_id_t *restrict stack, - int8_t *restrict descending_bitset, tsk_id_t *restrict descending_list, - tsk_id_t *restrict not_descending_list, double *D) -{ - const tsk_id_t *restrict left_child = tree->left_child; - const tsk_id_t *restrict right_sib = tree->right_sib; - int stack_top; - tsk_id_t a, u, v; - tsk_size_t j, num_descending, num_not_descending; - - tsk_memset(descending_bitset, 0, num_samples * sizeof(*descending_bitset)); +#include - stack_top = 0; - stack[stack_top] = node; - while (stack_top >= 0) { - u = stack[stack_top]; - stack_top--; - a = sample_index_map[u]; - if (a != TSK_NULL) { - descending_bitset[a] = 1; - } - for (v = left_child[u]; v != TSK_NULL; v = right_sib[v]) { - stack_top++; - stack[stack_top] = v; - } - } - - num_descending = 0; - num_not_descending = 0; - for (j = 0; j < num_samples; j++) { - if (descending_bitset[j]) { - descending_list[num_descending] = (tsk_id_t) j; - num_descending++; - } else { - not_descending_list[num_not_descending] = (tsk_id_t) j; - num_not_descending++; +static void +update_site_divergence( + const tsk_variant_t *variant, tsk_id_t *restrict A, tsk_id_t *restrict B, double *D) +{ + const tsk_id_t *restrict genotypes = variant->genotypes; + tsk_id_t a; + tsk_size_t j, num_A, num_B; + + for (a = 0; a < 
(tsk_id_t) variant->num_alleles - 1; a++) { + num_A = 0; + num_B = 0; + for (j = 0; j < variant->num_samples; j++) { + if (genotypes[j] == a) { + A[num_A] = (tsk_id_t) j; + num_A++; + } else { + B[num_B] = (tsk_id_t) j; + num_B++; + } } + tsk_bug_assert(num_A + num_B == variant->num_samples); + increment_divergence_matrix_pairs(num_A, A, num_B, B, D); } - tsk_bug_assert(num_descending + num_not_descending == num_samples); - - increment_divergence_matrix_pairs( - num_descending, descending_list, num_not_descending, not_descending_list, D); } static int tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_samples, - tsk_size_t num_windows, const double *restrict windows, - tsk_flags_t TSK_UNUSED(options), const tsk_id_t *restrict sample_index_map, + const tsk_id_t *restrict samples, tsk_size_t num_windows, + const double *restrict windows, tsk_flags_t TSK_UNUSED(options), double *restrict result) { int ret = 0; - tsk_tree_t tree; - tsk_size_t i, tree_site, tree_mut; - tsk_site_t site; - tsk_mutation_t mut; - double left, right, span_left, span_right; + tsk_size_t i; + tsk_id_t site_id; + double left, right; double *restrict D; - const tsk_size_t num_nodes = self->tables->nodes.num_rows; - int8_t *descending_bitset = tsk_malloc(num_samples * sizeof(*descending_bitset)); + const tsk_id_t num_sites = (tsk_id_t) self->tables->sites.num_rows; + const double *restrict sites_position = self->tables->sites.position; tsk_id_t *descending_list = tsk_malloc(num_samples * sizeof(*descending_list)); tsk_id_t *not_descending_list = tsk_malloc(num_samples * sizeof(*not_descending_list)); - /* Do *not* use tsk_tree_get_size bound here because it gives a per-tree - * bound, not a global one! 
*/ - tsk_id_t *stack = tsk_malloc(num_nodes * sizeof(*stack)); + tsk_variant_t variant; - ret = tsk_tree_init(&tree, self, 0); + ret = tsk_variant_init( + &variant, self, samples, num_samples, NULL, TSK_ISOLATED_NOT_MISSING); if (ret != 0) { goto out; } - if (descending_bitset == NULL || descending_list == NULL - || not_descending_list == NULL || stack == NULL) { + if (descending_list == NULL || not_descending_list == NULL) { ret = TSK_ERR_NO_MEMORY; goto out; } + site_id = 0; + while (site_id < num_sites && sites_position[site_id] < windows[0]) { + site_id++; + } + for (i = 0; i < num_windows; i++) { left = windows[i]; right = windows[i + 1]; D = result + i * num_samples * num_samples; - ret = tsk_tree_seek(&tree, left, 0); - if (ret != 0) { - goto out; - } - while (tree.interval.left < right && tree.index != -1) { - span_left = TSK_MAX(tree.interval.left, left); - span_right = TSK_MIN(tree.interval.right, right); - for (tree_site = 0; tree_site < tree.sites_length; tree_site++) { - site = tree.sites[tree_site]; - if (span_left <= site.position && site.position < span_right) { - for (tree_mut = 0; tree_mut < site.mutations_length; tree_mut++) { - mut = site.mutations[tree_mut]; - update_site_divergence(&tree, mut.node, sample_index_map, - num_samples, stack, descending_bitset, descending_list, - not_descending_list, D); - } - } - } - - ret = tsk_tree_next(&tree); - if (ret < 0) { + if (site_id < num_sites) { + tsk_bug_assert(sites_position[site_id] >= left); + } + while (site_id < num_sites && sites_position[site_id] < right) { + ret = tsk_variant_decode(&variant, site_id, 0); + if (ret != 0) { goto out; } + update_site_divergence(&variant, descending_list, not_descending_list, D); + site_id++; } } ret = 0; out: - tsk_tree_free(&tree); - tsk_safe_free(descending_bitset); + tsk_variant_free(&variant); tsk_safe_free(descending_list); tsk_safe_free(not_descending_list); - tsk_safe_free(stack); return ret; } @@ -6859,7 +6826,7 @@ tsk_treeseq_divergence_matrix(const 
tsk_treeseq_t *self, tsk_size_t num_samples, } else { tsk_bug_assert(stat_site); ret = tsk_treeseq_divergence_matrix_site( - self, n, num_windows, windows, options, sample_index_map, result); + self, n, samples, num_windows, windows, options, result); } if (ret != 0) { goto out; diff --git a/python/tests/test_divmat.py b/python/tests/test_divmat.py index c6df718f83..362a7baa6b 100644 --- a/python/tests/test_divmat.py +++ b/python/tests/test_divmat.py @@ -281,31 +281,29 @@ def site_divergence_matrix(ts, windows=None, samples=None): sample_index_map[samples] = np.arange(n) is_descendant = np.zeros(n, dtype=bool) D = np.zeros((num_windows, n, n)) - tree = tskit.Tree(ts) + site_id = 0 + while site_id < ts.num_sites and ts.sites_position[site_id] < windows[0]: + site_id += 1 + + # Note we have to pass isolated_as_missing=False here because we're working with + # non-sample nodes. There are tricky problems here later with missing data. + variant = tskit.Variant(ts, samples=samples, isolated_as_missing=False) for i in range(num_windows): left = windows[i] right = windows[i + 1] - tree.seek(left) - # Iterate over the trees in this window - while tree.interval.left < right and tree.index != -1: - span_left = max(tree.interval.left, left) - span_right = min(tree.interval.right, right) - for site in tree.sites(): - if span_left <= site.position < span_right: - for mutation in site.mutations: - descendants = [] - for u in tree.nodes(mutation.node): - if sample_index_map[u] != -1: - is_descendant[sample_index_map[u]] = True - - descendants = np.where(is_descendant)[0] - not_descendants = np.where(np.logical_not(is_descendant))[0] - for j in descendants: - for k in not_descendants: - D[i, j, k] += 1 - D[i, k, j] += 1 - is_descendant[:] = False - tree.next() + if site_id < ts.num_sites: + assert ts.sites_position[site_id] >= left + while site_id < ts.num_sites and ts.sites_position[site_id] < right: + variant.decode(site_id) + max_allele = np.max(variant.genotypes) + for a in 
range(max_allele): + A = np.where(variant.genotypes == a)[0] + B = np.where(variant.genotypes != a)[0] + for j in A: + for k in B: + D[i, j, k] += 1 + D[i, k, j] += 1 + site_id += 1 if not windows_specified: D = D[0] return D @@ -350,7 +348,7 @@ class TestExamplesWithAnswer: @pytest.mark.parametrize("mode", DIVMAT_MODES) def test_single_tree_zero_samples(self, mode): ts = tskit.Tree.generate_balanced(2).tree_sequence - D = check_divmat(ts, samples=[], mode="site") + D = check_divmat(ts, samples=[], mode=mode) assert D.shape == (0, 0) @pytest.mark.parametrize("num_windows", [1, 2, 3, 5]) @@ -358,7 +356,7 @@ def test_single_tree_zero_samples(self, mode): def test_single_tree_zero_samples_windows(self, num_windows, mode): ts = tskit.Tree.generate_balanced(2).tree_sequence windows = np.linspace(0, ts.sequence_length, num=num_windows + 1) - D = check_divmat(ts, samples=[], windows=windows, mode="site") + D = check_divmat(ts, samples=[], windows=windows, mode=mode) assert D.shape == (num_windows, 0, 0) @pytest.mark.parametrize("m", [0, 1, 2, 10]) @@ -804,10 +802,10 @@ def check(self, ts, windows=None, samples=None, num_threads=0, mode="branch"): np.testing.assert_allclose(D1, D2, atol=atol) else: assert mode == "site" - if np.any(ts.mutations_parent != tskit.NULL): - # The stats API computes something slightly different when we have - # recurrent mutations, so fall back to the naive version. - D2 = site_divergence_matrix(ts, windows=windows, samples=samples) + # if np.any(ts.mutations_parent != tskit.NULL): + # # The stats API computes something slightly different when we have + # # recurrent mutations, so fall back to the naive version. + # D2 = site_divergence_matrix(ts, windows=windows, samples=samples) np.testing.assert_array_equal(D1, D2) @pytest.mark.parametrize("ts", get_example_tree_sequences())