diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c index 39f2a063e0..60ad93e920 100644 --- a/c/tests/test_stats.c +++ b/c/tests/test_stats.c @@ -1132,7 +1132,6 @@ test_single_tree_divergence_matrix_multi_root(void) int ret; double result[16]; double D_branch[16] = { 0, 2, 3, 3, 2, 0, 3, 3, 3, 3, 0, 4, 3, 3, 4, 0 }; - double D_site[16] = { 0, 4, 6, 6, 4, 0, 6, 6, 6, 6, 0, 8, 6, 6, 8, 0 }; const char *nodes = "1 0 -1 -1\n" "1 0 -1 -1\n" /* 2.00┊ 5 ┊ */ @@ -1142,7 +1141,7 @@ test_single_tree_divergence_matrix_multi_root(void) "0 2 -1 -1\n"; /* 0 * * * * 1 */ const char *edges = "0 1 4 0,1\n" "0 1 5 2,3\n"; - /* Two mutations per branch unit so we get twice branch length value */ + /* Two mutations per branch */ const char *sites = "0.1 A\n" "0.2 A\n" "0.3 A\n" @@ -1166,9 +1165,8 @@ test_single_tree_divergence_matrix_multi_root(void) CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(16, result, D_branch); - ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_SITE, result); - CU_ASSERT_EQUAL_FATAL(ret, 0); - assert_arrays_almost_equal(16, result, D_site); + verify_divergence_matrix(&ts, TSK_STAT_SITE); + verify_divergence_matrix(&ts, TSK_STAT_BRANCH); tsk_treeseq_free(&ts); } @@ -2041,6 +2039,13 @@ test_simplest_divergence_matrix(void) ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + sample_ids[0] = 1; + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_DUPLICATE_SAMPLE); + ret = tsk_treeseq_divergence_matrix( + &ts, 2, sample_ids, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_DUPLICATE_SAMPLE); + tsk_treeseq_free(&ts); } diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 8a3d0afc95..83657f0266 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -6597,43 +6597,62 @@ tsk_treeseq_divergence_matrix_branch(const tsk_treeseq_t *self, tsk_size_t num_s return ret; } -static tsk_size_t -count_mutations_on_path(tsk_id_t u, tsk_id_t v, const tsk_id_t *restrict parent, - const double *restrict time, const tsk_size_t *restrict mutations_per_node) +// FIXME see #2817 +// Just including this here for now as it's the simplest option. Everything +// will probably move to stats.[c,h] in the near future though, and it +// can pull in ``genotypes.h`` without issues. +#include + +static void +update_site_divergence(const tsk_variant_t *var, const tsk_id_t *restrict A, + const tsk_size_t *restrict offsets, double *D) + { - double tu, tv; - tsk_size_t count = 0; + const tsk_size_t num_alleles = var->num_alleles; + const tsk_id_t n = (tsk_id_t) var->num_samples; - tu = time[u]; - tv = time[v]; - while (u != v) { - if (tu < tv) { - count += mutations_per_node[u]; - u = parent[u]; - if (u == TSK_NULL) { - break; - } - tu = time[u]; - } else { - count += mutations_per_node[v]; - v = parent[v]; - if (v == TSK_NULL) { - break; + tsk_size_t a, b, j, k; + tsk_id_t u, v; + + for (a = 0; a < num_alleles; a++) { + for (b = a + 1; b < num_alleles; b++) { + for (j = offsets[a]; j < offsets[a + 1]; j++) { + for (k = offsets[b]; k < offsets[b + 1]; k++) { + u = A[j]; + v = A[k]; + /* Only increment the upper triangle to (hopefully) improve memory + * access patterns */ + if (u > v) { + v = A[j]; + u = A[k]; + } + D[u * n + v]++; + } } - tv = time[v]; } } - if (u != v) { - while (u != TSK_NULL) { - count += mutations_per_node[u]; - u = parent[u]; - } - while (v != TSK_NULL) { - count += mutations_per_node[v]; - v = parent[v]; +} + +static void +group_alleles(const tsk_variant_t *var, tsk_id_t *restrict A, tsk_size_t *offsets) +{ + const tsk_size_t n = var->num_samples; + const tsk_id_t *restrict genotypes = var->genotypes; + tsk_id_t a; + tsk_size_t j, k; + + k = 0; + offsets[0] = 0; + for (a = 0; a < (tsk_id_t) var->num_alleles; a++) { + offsets[a + 1] = offsets[a]; + for (j = 0; j < n; j++) { + if (genotypes[j] == a) { + offsets[a + 1]++; + A[k] = (tsk_id_t) j; + k++; + } } } - return count; } static int @@ -6643,72 +6662,100 @@ tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_sam double *restrict result) { int ret = 0; - tsk_tree_t tree; - const tsk_size_t n = num_samples; - const tsk_size_t num_nodes = self->tables->nodes.num_rows; - const double *restrict nodes_time = self->tables->nodes.time; - tsk_size_t i, j, k, tree_site, tree_mut; - tsk_site_t site; - tsk_mutation_t mut; - tsk_id_t u, v; - double left, right, span_left, span_right; + tsk_size_t i; + tsk_id_t site_id; + double left, right; double *restrict D; - tsk_size_t *mutations_per_node = tsk_malloc(num_nodes * sizeof(*mutations_per_node)); - - ret = tsk_tree_init(&tree, self, 0); + const tsk_id_t num_sites = (tsk_id_t) self->tables->sites.num_rows; + const double *restrict sites_position = self->tables->sites.position; + tsk_id_t *A = tsk_malloc(num_samples * sizeof(*A)); + /* Allocate the allele offsets at the first variant */ + tsk_size_t max_alleles = 0; + tsk_size_t *allele_offsets = NULL; + tsk_variant_t variant; + + ret = tsk_variant_init( + &variant, self, samples, num_samples, NULL, TSK_ISOLATED_NOT_MISSING); if (ret != 0) { goto out; } - if (mutations_per_node == NULL) { + + if (A == NULL) { ret = TSK_ERR_NO_MEMORY; goto out; } + site_id = 0; + while (site_id < num_sites && sites_position[site_id] < windows[0]) { + site_id++; + } + for (i = 0; i < num_windows; i++) { left = windows[i]; right = windows[i + 1]; - D = result + i * n * n; - ret = tsk_tree_seek(&tree, left, 0); - if (ret != 0) { - goto out; - } - while (tree.interval.left < right && tree.index != -1) { - span_left = TSK_MAX(tree.interval.left, left); - span_right = TSK_MIN(tree.interval.right, right); + D = result + i * num_samples * num_samples; - /* NOTE: we could avoid this full memset across all nodes by doing - * the same loops again and decrementing at the end of the main - * tree-loop. It's probably not worth it though, because of the - * overwhelming O(n^2) below */ - tsk_memset(mutations_per_node, 0, num_nodes * sizeof(*mutations_per_node)); - for (tree_site = 0; tree_site < tree.sites_length; tree_site++) { - site = tree.sites[tree_site]; - if (span_left <= site.position && site.position < span_right) { - for (tree_mut = 0; tree_mut < site.mutations_length; tree_mut++) { - mut = site.mutations[tree_mut]; - mutations_per_node[mut.node]++; - } - } + if (site_id < num_sites) { + tsk_bug_assert(sites_position[site_id] >= left); + } + while (site_id < num_sites && sites_position[site_id] < right) { + ret = tsk_variant_decode(&variant, site_id, 0); + if (ret != 0) { + goto out; } - - for (j = 0; j < n; j++) { - u = samples[j]; - for (k = j + 1; k < n; k++) { - v = samples[k]; - D[j * n + k] += (double) count_mutations_on_path( - u, v, tree.parent, nodes_time, mutations_per_node); + if (variant.num_alleles > max_alleles) { + /* could do some kind of doubling here, but there's no + * point - just keep it simple for testing. */ + max_alleles = variant.num_alleles; + tsk_safe_free(allele_offsets); + allele_offsets = tsk_malloc((max_alleles + 1) * sizeof(*allele_offsets)); + if (allele_offsets == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; } } - ret = tsk_tree_next(&tree); - if (ret < 0) { - goto out; - } + group_alleles(&variant, A, allele_offsets); + update_site_divergence(&variant, A, allele_offsets, D); + site_id++; } } ret = 0; out: - tsk_tree_free(&tree); - tsk_safe_free(mutations_per_node); + tsk_variant_free(&variant); + tsk_safe_free(A); + tsk_safe_free(allele_offsets); + return ret; +} + +static int +get_sample_index_map(const tsk_size_t num_nodes, const tsk_size_t num_samples, + const tsk_id_t *restrict samples, tsk_id_t **ret_sample_index_map) +{ + int ret = 0; + tsk_size_t j; + tsk_id_t u; + tsk_id_t *sample_index_map = tsk_malloc(num_nodes * sizeof(*sample_index_map)); + + if (sample_index_map == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + /* Assign the output pointer here so that it will be freed in the case + * of an error raised in the input checking */ + *ret_sample_index_map = sample_index_map; + + for (j = 0; j < num_nodes; j++) { + sample_index_map[j] = TSK_NULL; + } + for (j = 0; j < num_samples; j++) { + u = samples[j]; + if (sample_index_map[u] != TSK_NULL) { + ret = TSK_ERR_DUPLICATE_SAMPLE; + goto out; + } + sample_index_map[u] = (tsk_id_t) j; + } +out: return ret; } @@ -6739,9 +6786,11 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples, const tsk_id_t *samples = self->samples; tsk_size_t n = self->num_samples; const double default_windows[] = { 0, self->tables->sequence_length }; + const tsk_size_t num_nodes = self->tables->nodes.num_rows; bool stat_site = !!(options & TSK_STAT_SITE); bool stat_branch = !!(options & TSK_STAT_BRANCH); bool stat_node = !!(options & TSK_STAT_NODE); + tsk_id_t *sample_index_map = NULL; if (stat_node) { ret = TSK_ERR_UNSUPPORTED_STAT_MODE; @@ -6785,6 +6834,11 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples, } } + ret = get_sample_index_map(num_nodes, n, samples, &sample_index_map); + if (ret != 0) { + goto out; + } + tsk_memset(result, 0, num_windows * n * n * sizeof(*result)); if (stat_branch) { @@ -6801,5 +6855,6 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples, fill_lower_triangle(result, n, num_windows); out: + tsk_safe_free(sample_index_map); return ret; } diff --git a/python/tests/test_divmat.py b/python/tests/test_divmat.py index acb2403d41..96137617cc 100644 --- a/python/tests/test_divmat.py +++ b/python/tests/test_divmat.py @@ -22,8 +22,6 @@ """ Test cases for divergence matrix based pairwise stats """ -import collections - import msprime import numpy as np import pytest @@ -266,10 +264,19 @@ def stats_api_divergence_matrix(ts, windows=None, samples=None, mode="site"): return out -def rootward_path(tree, u, v): - while u != v: - yield u - u = tree.parent(u) +def group_alleles(genotypes, num_alleles): + n = genotypes.shape[0] + A = np.zeros(n, dtype=int) + offsets = np.zeros(num_alleles + 1, dtype=int) + k = 0 + for a in range(num_alleles): + offsets[a + 1] = offsets[a] + for j in range(n): + if genotypes[j] == a: + offsets[a + 1] += 1 + A[k] = j + k += 1 + return A, offsets def site_divergence_matrix(ts, windows=None, samples=None): @@ -279,42 +286,33 @@ def site_divergence_matrix(ts, windows=None, samples=None): samples = ts.samples() if samples is None else samples n = len(samples) + sample_index_map = np.zeros(ts.num_nodes, dtype=int) - 1 + sample_index_map[samples] = np.arange(n) D = np.zeros((num_windows, n, n)) - tree = tskit.Tree(ts) + site_id = 0 + while site_id < ts.num_sites and ts.sites_position[site_id] < windows[0]: + site_id += 1 + + # Note we have to use isolated_as_missing here because we're working with + # non-sample nodes. There are tricky problems here later with missing data. + variant = tskit.Variant(ts, samples=samples, isolated_as_missing=False) for i in range(num_windows): left = windows[i] right = windows[i + 1] - tree.seek(left) - # Iterate over the trees in this window - while tree.interval.left < right and tree.index != -1: - span_left = max(tree.interval.left, left) - span_right = min(tree.interval.right, right) - mutations_per_node = collections.Counter() - for site in tree.sites(): - if span_left <= site.position < span_right: - for mutation in site.mutations: - mutations_per_node[mutation.node] += 1 - for j in range(n): - u = samples[j] - for k in range(j + 1, n): - v = samples[k] - w = tree.mrca(u, v) - if w != tskit.NULL: - wu = w - wv = w - else: - wu = local_root(tree, u) - wv = local_root(tree, v) - du = sum(mutations_per_node[x] for x in rootward_path(tree, u, wu)) - dv = sum(mutations_per_node[x] for x in rootward_path(tree, v, wv)) - # NOTE: we're just accumulating the raw mutation counts, not - # multiplying by span - D[i, j, k] += du + dv - tree.next() - # Fill out symmetric triangle in the matrix - for j in range(n): - for k in range(j + 1, n): - D[i, k, j] = D[i, j, k] + if site_id < ts.num_sites: + assert ts.sites_position[site_id] >= left + while site_id < ts.num_sites and ts.sites_position[site_id] < right: + variant.decode(site_id) + X, offsets = group_alleles(variant.genotypes, variant.num_alleles) + for j in range(variant.num_alleles): + A = X[offsets[j] : offsets[j + 1]] + for k in range(j + 1, variant.num_alleles): + B = X[offsets[k] : offsets[k + 1]] + for a in A: + for b in B: + D[i, a, b] += 1 + D[i, b, a] += 1 + site_id += 1 if not windows_specified: D = D[0] return D @@ -359,7 +357,7 @@ class TestExamplesWithAnswer: @pytest.mark.parametrize("mode", DIVMAT_MODES) def test_single_tree_zero_samples(self, mode): ts = tskit.Tree.generate_balanced(2).tree_sequence - D = check_divmat(ts, samples=[], mode="site") + D = check_divmat(ts, samples=[], mode=mode) assert D.shape == (0, 0) @pytest.mark.parametrize("num_windows", [1, 2, 3, 5]) @@ -367,7 +365,7 @@ def test_single_tree_zero_samples(self, mode): def test_single_tree_zero_samples_windows(self, num_windows, mode): ts = tskit.Tree.generate_balanced(2).tree_sequence windows = np.linspace(0, ts.sequence_length, num=num_windows + 1) - D = check_divmat(ts, samples=[], windows=windows, mode="site") + D = check_divmat(ts, samples=[], windows=windows, mode=mode) assert D.shape == (num_windows, 0, 0) @pytest.mark.parametrize("m", [0, 1, 2, 10]) @@ -391,6 +389,8 @@ def test_single_tree_sites_per_branch(self, m): ) np.testing.assert_array_equal(D1, m * D2) + # Disabling while we're experimenting with variant-based definition + @pytest.mark.skip() @pytest.mark.parametrize("m", [0, 1, 2, 10]) def test_single_tree_mutations_per_branch(self, m): # 2.00┊ 6 ┊ @@ -415,6 +415,24 @@ def test_single_tree_mutations_per_branch(self, m): ) np.testing.assert_array_equal(D1, m * D2) + @pytest.mark.parametrize("n", [2, 3, 5]) + def test_single_tree_unique_sample_alleles(self, n): + tables = tskit.Tree.generate_balanced(n).tree_sequence.dump_tables() + tables.sites.add_row(position=0.5, ancestral_state="0") + for j in range(n): + tables.mutations.add_row(site=0, node=j, derived_state=f"{j + 1}") + ts = tables.tree_sequence() + check_divmat(ts, mode="site") + # D2 = np.array( + # [ + # [0.0, 2.0, 4.0, 4.0], + # [2.0, 0.0, 4.0, 4.0], + # [4.0, 4.0, 0.0, 2.0], + # [4.0, 4.0, 2.0, 0.0], + # ] + # ) + # np.testing.assert_array_equal(D1, m * D2) + @pytest.mark.parametrize("L", [0.1, 1, 2, 100]) def test_single_tree_sequence_length(self, L): # 2.00┊ 6 ┊ @@ -511,15 +529,8 @@ def test_single_tree_duplicate_samples(self, mode): # 0 1 ts = tskit.Tree.generate_balanced(4).tree_sequence ts = tsutil.insert_branch_sites(ts) - D1 = check_divmat(ts, samples=[0, 0, 1], compare_stats_api=False, mode=mode) - D2 = np.array( - [ - [0.0, 0.0, 2.0], - [0.0, 0.0, 2.0], - [2.0, 2.0, 0.0], - ] - ) - np.testing.assert_array_equal(D1, D2) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_DUPLICATE_SAMPLE"): + ts.divergence_matrix(samples=[0, 0, 1], mode=mode) @pytest.mark.parametrize("mode", DIVMAT_MODES) def test_single_tree_multiroot(self, mode): @@ -820,10 +831,6 @@ def check(self, ts, windows=None, samples=None, num_threads=0, mode="branch"): np.testing.assert_allclose(D1, D2, atol=atol) else: assert mode == "site" - if np.any(ts.mutations_parent != tskit.NULL): - # The stats API computes something slightly different when we have - # recurrent mutations, so fall back to the naive version. - D2 = site_divergence_matrix(ts, windows=windows, samples=samples) np.testing.assert_array_equal(D1, D2) @pytest.mark.parametrize("ts", get_example_tree_sequences()) @@ -1062,3 +1069,47 @@ def test_examples(self, windows, num_chunks, expected): def test_bad_chunks(self, num_chunks): with pytest.raises(ValueError, match="Number of chunks must be an integer > 0"): tskit.TreeSequence._chunk_windows([0, 1], num_chunks) + + +class TestGroupAlleles: + @pytest.mark.parametrize( + ["G", "num_alleles", "A", "offsets"], + [ + ([0, 1], 2, [0, 1], [0, 1, 2]), + ([0, 1], 3, [0, 1], [0, 1, 2, 2]), + ([0, 2], 3, [0, 1], [0, 1, 1, 2]), + ([1, 0], 2, [1, 0], [0, 1, 2]), + ([0, 0, 0, 1, 1, 1], 2, [0, 1, 2, 3, 4, 5], [0, 3, 6]), + ([0, 0], 1, [0, 1], [0, 2]), + ([2, 2], 3, [0, 1], [0, 0, 0, 2]), + ], + ) + def test_examples(self, G, num_alleles, A, offsets): + A1, offsets1 = group_alleles(np.array(G), num_alleles) + assert list(A) == list(A1) + assert list(offsets) == list(offsets1) + + def test_simple_simulation(self): + ts = msprime.sim_ancestry( + 15, + ploidy=1, + population_size=20, + sequence_length=100, + recombination_rate=0.01, + random_seed=1234, + ) + ts = msprime.sim_mutations(ts, rate=0.01, random_seed=1234) + assert ts.num_mutations > 10 + for var in ts.variants(): + A, offsets = group_alleles(var.genotypes, var.num_alleles) + allele_samples = [[] for _ in range(var.num_alleles)] + for j, a in enumerate(var.genotypes): + allele_samples[a].append(j) + + assert len(offsets) == var.num_alleles + 1 + assert offsets[0] == 0 + assert offsets[-1] == ts.num_samples + assert np.all(np.diff(offsets) >= 0) + for j in range(var.num_alleles): + a = A[offsets[j] : offsets[j + 1]] + assert list(a) == list(allele_samples[j])