diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index 58b52debb8..dbc3695492 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -24,8 +24,12 @@ """ import contextlib import io +from copy import deepcopy +from dataclasses import dataclass +from itertools import combinations from itertools import combinations_with_replacement from itertools import permutations +from itertools import product from typing import Any from typing import Callable from typing import Dict @@ -38,6 +42,7 @@ import pytest import tskit +from tests import tsutil from tests.test_highlevel import get_example_tree_sequences @@ -603,6 +608,99 @@ def two_site_count_stat( return result +def two_branch_count_stat( + ts: tskit.TreeSequence, + func: Callable[[int, np.ndarray, np.ndarray, Dict[str, Any]], None], + norm_func: Callable[ # No need for norm func, biallelic + [int, np.ndarray, int, int, np.ndarray, Dict[str, Any]], None + ], + num_sample_sets: int, + sample_set_sizes: np.ndarray, + sample_sets: BitSet, # TODO: implement sample sets + sample_index_map: np.ndarray, + row_sites: np.ndarray, # TODO: positions + col_sites: np.ndarray, # TODO: positions + polarised: bool, # TODO: polarisation +) -> np.ndarray: + """ + Compute a tree X tree LD matrix by walking along the tree sequence and + computing haplotype counts. This method incrementally adds and removes + branches from a tree sequence and updates the stat based on sample additions + and removals. We bifurcate the tree with a given branch on each locus and + intersect the samples under each branch to produce haplotype counts. We + output a full LD matrix for the entire tree sequence. + + :param ts: Tree sequence to gather data from. + :param func: Function used to compute each two-locus statistic. + :param norm_func: Not (YET) applicable for branch stats: TODO? + :param num_sample_sets: Number of sample sets that we will consider. 
+ :param sample_set_sizes: Number of samples in each sample set. + :param sample_sets: BitSet of samples to compute stats for. We will only + consider these samples in our computations, resulting + in stats that are computed on subsets of the samples + on the tree sequence. + :param row_sites: Sites contained in the rows of the output matrix. + :param col_sites: Sites contained in the columns of the output matrix. + :param polarised: If true, skip the computation of the statistic for the + ancestral state. + :returns: 3D array of results, dimensions (sample_sets, row_sites, col_sites). + """ + # state_dim = len(sample_set_sizes) + # params = {"sample_set_sizes": sample_set_sizes} + result = np.zeros( + (num_sample_sets, len(row_sites), len(col_sites)), dtype=np.float64 + ) + # These stats require at least 4 samples in the tree + if ts.num_samples < 4: + return result + + # TODO: get_pos_row_col_indices? + # sites, row_idx, col_idx = get_site_row_col_indices(row_sites, col_sites) + + stat = 0 + l_state = TreeState(ts) # State is initialized at tree -1 + r_state = TreeState(ts) + tmp = None # Tmp tree will be used for swapping below. + tmp_stat = None + + # Advance the fixed (left) tree state by adding all edges. Stat will be 0 + stat, l_state = compute_stat(ts, func, stat, l_state, TreeState(ts)) + for i in range(ts.num_trees): + # Compute the stat between the left and right state, advancing the right + # state to the SAME tree that the lefthand tree represents. If we start + # at tree 0 and tree -1, we end up at tree 0 and tree 0, etc. + stat, r_state = compute_stat(ts, func, stat, l_state, r_state) + # TODO: sample sets + result[0, l_state.pos.index, r_state.pos.index] = stat + # Continue to advance the righthand tree until we hit the end of the + # tree sequence, computing the stat for the rest of the upper triangle + # of the LD matrix. 
+ for _ in range(i, ts.num_trees - 1): + stat, r_state = compute_stat(ts, func, stat, l_state, r_state) + # TODO: sample sets + result[0, l_state.pos.index, r_state.pos.index] = stat + # After the first iteration of this loop, we store the r_state and + # stat to be used as the lefthand tree in the next iteration. + if tmp is None: + tmp_stat = stat + tmp = deepcopy(r_state) + # We store the focal tree as the starting point for the next row in the + # LD matrix. Remember that in order to compute the association between + # tree 1 and tree 1, we need the righthand state to *start* at tree 0. + r_state = deepcopy(l_state) + if tmp is not None: + l_state = deepcopy(tmp) + stat = tmp_stat + tmp = None + + # Reflect the upper triangle of the LD matrix to the lower triangle. + tril_idx = np.tril_indices(len(result[0]), k=-1) + # TODO: sample sets + tril_idx = (np.zeros(len(tril_idx[0]), dtype=int), *tril_idx) + result[tril_idx] = result[tril_idx[0], tril_idx[2], tril_idx[1]] + return result + + def sample_sets_to_bit_array( ts: tskit.TreeSequence, sample_sets: List[List[int]] ) -> Tuple[np.ndarray, np.ndarray, BitSet]: @@ -648,6 +746,7 @@ def two_locus_count_stat( summary_func, norm_func, polarised, + mode, sites=None, sample_sets=None, ): @@ -692,18 +791,34 @@ def two_locus_count_stat( sample_index_map, ss_sizes, ss_bits = sample_sets_to_bit_array(ts, sample_sets) - result = two_site_count_stat( - ts, - summary_func, - norm_func, - len(ss_sizes), - ss_sizes, - ss_bits, - sample_index_map, - row_sites, - col_sites, - polarised, - ) + if mode == "site": + result = two_site_count_stat( + ts, + summary_func, + norm_func, + len(ss_sizes), + ss_sizes, + ss_bits, + sample_index_map, + row_sites, + col_sites, + polarised, + ) + elif mode == "branch": + result = two_branch_count_stat( + ts, + summary_func, + None, + 1, + [ts.num_samples], + None, + None, + range(ts.num_trees), + range(ts.num_trees), + False, + ) + else: + raise ValueError(f"Unknown mode: {mode}") # If there is 
one sample set, return a 2d numpy array of row/site LD if len(sample_sets) == 1: @@ -848,6 +963,48 @@ def pi2_summary_func( result[k] = p_A * (1 - p_A) * p_B * (1 - p_B) +# Unbiased estimators of pi2, dz, and d2. These are derived in Ragsdale 2019 +# (https://doi.org/10.1093/molbev/msz265) and can be used in place of the method +# outlined by McVean 2002. The reason for using haplotype counts in the branch +# methods is that we can compute statistics that cannot be represented by tMRCA +# covariance. With these unbiased estimators, we still reproduce the values +# estimated with tMRCA covariance. + +# TODO: update these summary functions to have the same function signature as +# the summary functions defined above. + + +def pi2_unbiased(w_AB, w_Ab, w_aB, n): + w_ab = n - (w_AB + w_Ab + w_aB) + return (1 / (n * (n - 1) * (n - 2) * (n - 3))) * ( + ((w_AB + w_Ab) * (w_aB + w_ab) * (w_AB + w_aB) * (w_Ab + w_ab)) + - ((w_AB * w_ab) * (w_AB + w_ab + (3 * w_Ab) + (3 * w_aB) - 1)) + - ((w_Ab * w_aB) * (w_Ab + w_aB + (3 * w_AB) + (3 * w_ab) - 1)) + ) + + +def dz_unbiased(w_AB, w_Ab, w_aB, n): + w_ab = n - (w_AB + w_Ab + w_aB) + return (1 / (n * (n - 1) * (n - 2) * (n - 3))) * ( + ( + ((w_AB * w_ab) - (w_Ab * w_aB)) + * (w_aB + w_ab - w_AB - w_Ab) + * (w_Ab + w_ab - w_AB - w_aB) + ) + - ((w_AB * w_ab) * (w_AB + w_ab - w_Ab - w_aB - 2)) + - ((w_Ab * w_aB) * (w_Ab + w_aB - w_AB - w_ab - 2)) + ) + + +def d2_unbiased(w_AB, w_Ab, w_aB, n): + w_ab = n - (w_AB + w_Ab + w_aB) + return (1 / (n * (n - 1) * (n - 2) * (n - 3))) * ( + ((w_aB**2) * (w_Ab - 1) * w_Ab) + + ((w_ab - 1) * w_ab * (w_AB - 1) * w_AB) + - (w_aB * w_Ab * (w_Ab + (2 * w_ab * w_AB) - 1)) + ) + + SUMMARY_FUNCS = { "r": r_summary_func, "r2": r2_summary_func, @@ -856,6 +1013,9 @@ def pi2_summary_func( "D_prime": D_prime_summary_func, "pi2": pi2_summary_func, "Dz": Dz_summary_func, + "d2_unbiased": d2_unbiased, + "dz_unbiased": dz_unbiased, + "pi2_unbiased": pi2_unbiased, } NORM_METHOD = { @@ -866,6 +1026,9 @@ def 
pi2_summary_func( pi2_summary_func: norm_total_weighted, r_summary_func: norm_total_weighted, r2_summary_func: norm_hap_weighted, + d2_unbiased: None, + dz_unbiased: None, + pi2_unbiased: None, } POLARIZATION = { @@ -876,21 +1039,20 @@ def pi2_summary_func( pi2_summary_func: False, r_summary_func: True, r2_summary_func: False, + d2_unbiased: False, + dz_unbiased: False, + pi2_unbiased: False, } -def ld_matrix( - ts, - sample_sets=None, - sites=None, - stat="r2", -): +def ld_matrix(ts, sample_sets=None, sites=None, stat="r2", mode="site"): summary_func = SUMMARY_FUNCS[stat] return two_locus_count_stat( ts, summary_func, NORM_METHOD[summary_func], POLARIZATION[summary_func], + mode, sites=sites, sample_sets=sample_sets, ) @@ -1049,7 +1211,10 @@ def test_compare_to_ld_calculator(): np.testing.assert_array_almost_equal(ld_calc.get_r2_matrix(), ts.ld_matrix()) -@pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys()) +@pytest.mark.parametrize( + "stat", + sorted(SUMMARY_FUNCS.keys() - {"d2_unbiased", "dz_unbiased", "pi2_unbiased"}), +) def test_multiallelic_with_back_mutation(stat): ts = msprime.sim_ancestry( samples=4, recombination_rate=0.2, sequence_length=10, random_seed=1 @@ -1066,7 +1231,11 @@ def test_multiallelic_with_back_mutation(stat): if ts.id not in {"no_samples", "empty_ts"} ], ) -@pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys()) +# TODO: port unbiased summary functions +@pytest.mark.parametrize( + "stat", + sorted(SUMMARY_FUNCS.keys() - {"d2_unbiased", "dz_unbiased", "pi2_unbiased"}), +) def test_ld_matrix(ts, stat): np.testing.assert_array_equal(ld_matrix(ts, stat=stat), ts.ld_matrix(stat=stat)) @@ -1092,3 +1261,323 @@ def test_input_validation(): ts.ld_matrix(sites=[[1, 2], [2, 3], [3, 4]]) with pytest.raises(ValueError, match="must be a length 1 or 2 list"): ts.ld_matrix(sites=[]) + + +@dataclass +class TreeState: + """ + Class for storing tree state from one iteration to the next. 
This object + enables easy copying of the state for computing a matrix. + """ + + pos: tsutil.TreePosition # current position in the tree sequence + parent: np.ndarray # parent node of a given node (connected by an edge) + branch_len: np.ndarray # length of the branch above a particular child node + node_samples: BitSet # samples that exist under a given node, this is a + # bitset with a row for each node. + + def __init__(self, ts): + self.pos = tsutil.TreePosition(ts) + self.parent = -np.ones(ts.num_nodes, dtype=np.int64) + self.branch_len = np.zeros(ts.num_nodes, dtype=np.float64) + self.node_samples = BitSet(ts.num_samples, ts.num_nodes) + for s in ts.samples(): + self.node_samples.add(s, s) + + +def compute_stat_update(c, child_samples, A_state, B_state, stat_func, num_samples): + """Compute an update to the two-locus statistic for a single subset of the + tree being modified, relative to all subsets of the fixed tree. We perform + this operation for all samples edge being modified. For subsequent parent + nodes, we update the statistic by removing the existing contribution after + adding in the update contribution. + + i.e. if we're adding two samples ({3, 4}) to a node, if the parent node + contains {1, 2}, we first add the statistic for {1, 2, 3, 4}, then + subtract the stat for {1, 2}. 
+ + :param c: Child node of the edge we're modifying + :param child_samples: Samples under the edge being added/removed + :param A_state: State for the tree contributing to the A samples (fixed) + :param B_state: State for the tree contributing to the B samples (modified) + :param stat_func: Function used to compute the two-locus statistic + :param num_samples: Number of samples in the tree sequence + :returns: The change to the statistic, given a single edge update in the tree + """ + stat = 0 + b_len = B_state.branch_len[c] + if b_len == 0: + return stat + AB_samples = BitSet(num_samples, 1) + node_samples_tmp = BitSet(num_samples, 1) + + for n in np.where(A_state.branch_len > 0)[0]: + a_len = A_state.branch_len[n] + # Samples under the modified edge and the current fixed tree node are AB + A_state.node_samples.intersect(n, B_state.node_samples, c, AB_samples) + w_AB = AB_samples.count(0) + w_A = A_state.node_samples.count(n) + w_Ab = w_A - w_AB + w_aB = B_state.node_samples.count(c) - w_AB + stat += stat_func(w_AB, w_Ab, w_aB, num_samples) * a_len * b_len + + # If we've begun our walk up the parents of the current edge removal, we + # must adjust the statistic for samples that were already present before + # addition or that remain after removal. + if child_samples is not None: + node_samples_tmp.union(0, B_state.node_samples, c) + node_samples_tmp.difference(0, child_samples, 0) + AB_samples.data[:] = 0 + # Zero out the bitset so that we can reuse it + A_state.node_samples.intersect(n, node_samples_tmp, 0, AB_samples) + w_AB = AB_samples.count(0) + w_Ab = w_A - w_AB + w_aB = node_samples_tmp.count(0) - w_AB + stat -= stat_func(w_AB, w_Ab, w_aB, num_samples) * a_len * b_len + return stat + + +def compute_stat(ts, stat_func, stat, l_state, r_state): + """Step between trees in a tree sequence, updating our two-locus statistic + as we add or remove edges.
Since we're computing statistics for two loci, we + have a focal tree that remains constant, and a tree that is updated to + represent the tree we're comparing to. The lefthand tree is held constant + and the righthand tree is modified. The statistic is updated as we add and + remove branches, and when we reach the point where the righthand tree is + fully updated, the statistic will have been updated to the two-locus + statistic between both trees. + + For instance, if we pass in the l_state for tree 0 and the r_state for tree + 0, we will update the r_state until r_state contains the information for + tree 1. Then, the statistic will represent the LD between tree 0 and tree 1. + + Currently, iteration happens in the forward direction. + + :param ts: The underlying tree sequence object that we're iterating across. + :param stat_func: A function that computes the two locus statistic, given + haplotype counts. + :param stat: The two-locus statistic computed between two trees. + :param l_state: The lefthand, constant state + :param r_state: The righthand, mutated state + :returns: A tuple containing the statistic between the two trees after + branch updates and the righthand tree state. + """ + time = ts.tables.nodes.time + r_pos = r_state.pos + # Iterate the right tree position to the next tree. The rest of the r_state + # data is not valid until the end of this function. + assert r_pos.next(), "out of bounds" + + child_samples = BitSet(ts.num_samples, 1) + for e in r_pos.out_range.order[r_pos.out_range.start : r_pos.out_range.stop]: + p = r_pos.ts.edges_parent[e] + c = r_pos.ts.edges_child[e] + child_samples.data[:] = 0 + child_samples.union(0, r_state.node_samples, c) # samples removed by this edge + + # Remove the LD contributed by the samples under removed edges. When + # we walk up the tree to propagate these changes to parents of the + # removed edge, we need to add back in the LD contributed by samples + # that aren't removed.
We remove samples from the parents of the removed + # branch as we propagate changes upward + in_parent = None + while p != tskit.NULL: + stat -= compute_stat_update( + c, in_parent, l_state, r_state, stat_func, ts.num_samples + ) + if in_parent is not None: + # remove samples from the parents of the branch being removed + # we remove the child node after the first iteration + r_state.node_samples.difference(c, child_samples, 0) + in_parent = child_samples + c = p + p = r_state.parent[p] + r_state.node_samples.difference(c, child_samples, 0) + + # reset to the child of the edge being removed. + c = ts.edges_child[e] + r_state.branch_len[c] = 0 + r_state.parent[c] = tskit.NULL + + for e in r_pos.in_range.order[r_pos.in_range.start : r_pos.in_range.stop]: + p = r_pos.ts.edges_parent[e] + c = r_pos.ts.edges_child[e] + child_samples.data[:] = 0 + child_samples.union(0, r_state.node_samples, c) # samples added by this edge + r_state.branch_len[c] = time[p] - time[c] + r_state.parent[c] = p + + # Add the LD contributed by the samples under added edges. When we walk + # up the tree to propagate these changes to parents of the added edge, + # we need to remove the LD contributed by samples that were already + # there + in_parent = None + while p != tskit.NULL: + r_state.node_samples.union(p, child_samples, 0) + stat += compute_stat_update( + c, in_parent, l_state, r_state, stat_func, ts.num_samples + ) + in_parent = child_samples + c = p + p = r_state.parent[p] + + return stat, r_state + + +# What follows is an implementation of two-locus statistics as described in +# McVean 2002 (https://doi.org/10.1093/genetics/162.2.987). We compute the +# covariance between coalescent times to produce expectations of coalescent +# times between three sampling patterns of samples. These expectations can be +# combined to produce D2, Dz, and pi2. These are for testing and to demonstrate +# conceptual parity between our method and McVean's method.
def _mean_centred_tmrca_product(x, y, node_quads):
    """
    Average the product of centred pairwise TMRCAs over ``node_quads``.

    For every ``(a, b, c, d)`` the TMRCA of ``(a, b)`` in ``x`` and the TMRCA
    of ``(c, d)`` in ``y`` are each reduced by the mean time of the two nodes
    involved, and the two centred values are multiplied together. Node times
    are always read from ``x``.

    :param x: Object providing ``time(u)`` and ``tmrca(u, v)``; queried for
        the first pair and for all node times.
    :param y: Object providing ``tmrca(u, v)``; queried for the second pair.
    :param node_quads: Iterable of ``(a, b, c, d)`` node tuples.
    :returns: Sum of the products divided by the number of tuples.
    :raises ZeroDivisionError: If ``node_quads`` yields no tuples.
    """
    total = 0
    count = 0
    for a, b, c, d in node_quads:
        ab_time = (x.time(a) + x.time(b)) / 2
        cd_time = (x.time(c) + x.time(d)) / 2
        total += (x.tmrca(a, b) - ab_time) * (y.tmrca(c, d) - cd_time)
        count += 1
    return total / count


def compute_D2(x, y, ij, ijk, ijkl):
    """
    Compute D2 between two trees from their pairwise coalescence times.

    :param x: First tree (must provide ``time`` and ``tmrca``).
    :param y: Second tree (must provide ``tmrca``).
    :param ij: Sample pairs ``(i, j)``; the same pair is used in both trees.
    :param ijk: Sample triples ``(i, j, k)``; ``(i, j)`` is used in ``x`` and
        ``(i, k)`` in ``y``.
    :param ijkl: Sample quadruples ``(i, j, k, l)``; ``(i, j)`` is used in
        ``x`` and ``(k, l)`` in ``y``.
    :returns: ``E_ijij - 2 * E_ijik + E_ijkl``.
    """
    E_ijij = _mean_centred_tmrca_product(x, y, ((i, j, i, j) for i, j in ij))
    E_ijik = _mean_centred_tmrca_product(x, y, ((i, j, i, k) for i, j, k in ijk))
    E_ijkl = _mean_centred_tmrca_product(x, y, ijkl)
    return E_ijij - 2 * E_ijik + E_ijkl


def compute_Dz(x, y, ij, ijk, ijkl):
    """
    Compute Dz between two trees from their pairwise coalescence times.

    ``ij`` is accepted only so that the signature matches ``compute_D2``; it
    is not used.

    :param x: First tree (must provide ``time`` and ``tmrca``).
    :param y: Second tree (must provide ``tmrca``).
    :param ij: Unused.
    :param ijk: Sample triples ``(i, j, k)``; ``(i, j)`` is used in ``x`` and
        ``(i, k)`` in ``y``.
    :param ijkl: Sample quadruples ``(i, j, k, l)``; ``(i, j)`` is used in
        ``x`` and ``(k, l)`` in ``y``.
    :returns: ``4 * (E_ijik - E_ijkl)``.
    """
    E_ijik = _mean_centred_tmrca_product(x, y, ((i, j, i, k) for i, j, k in ijk))
    E_ijkl = _mean_centred_tmrca_product(x, y, ijkl)
    return 4 * (E_ijik - E_ijkl)


def compute_pi2(x, y, ij, ijk, ijkl):
    """
    Compute pi2 between two trees from their pairwise coalescence times.

    ``ij`` and ``ijk`` are accepted only so that the signature matches
    ``compute_D2``; they are not used.

    :param x: First tree (must provide ``time`` and ``tmrca``).
    :param y: Second tree (must provide ``tmrca``).
    :param ij: Unused.
    :param ijk: Unused.
    :param ijkl: Sample quadruples ``(i, j, k, l)``; ``(i, j)`` is used in
        ``x`` and ``(k, l)`` in ``y``.
    :returns: ``E_ijkl``.
    """
    return _mean_centred_tmrca_product(x, y, ijkl)


def combine(samples):
    """
    Enumerate the sample index patterns consumed by the ``compute_*`` functions.

    :param samples: Sample node IDs.
    :returns: Tuple ``(ij, ijk, ijkl)`` of lists: unordered pairs, ordered
        triples of mutually different samples, and pairs of unordered pairs
        that share no sample.
    """
    # All combinations where i != j
    ij = list(combinations(samples, 2))
    # All combinations where i != {j,k} and j != k
    ijk = [
        (i, j, k)
        for i, j, k in product(samples, repeat=3)
        if i != k and i != j and j != k
    ]
    # All combinations where i != {k,l} and j != {k,l}. The second pair runs
    # over the same unordered pairs as the first, so the pair list is reused.
    ijkl = [
        (i, j, k, m)
        for i, j in ij
        for k, m in ij
        if i != k and j != k and m != i and m != j
    ]
    return ij, ijk, ijkl


def naive_matrix(ts, stat_func):
    """
    Compute a tree x tree LD matrix for a given tree sequence and two-locus
    statistic. This produces a matrix of LD that is generated from the
    covariance in gene genealogies, as described in McVean 2002.

    :param ts: Tree sequence to gather data from.
    :param stat_func: Function to compute a two-locus statistic from two
        materialized trees and sample combinations.
    :returns: Pairwise branch LD matrix for an entire tree sequence.
    """
    num_trees = ts.num_trees
    result = np.zeros((num_trees, num_trees), dtype=np.float64)
    # These stats require at least 4 samples in the tree
    if ts.num_samples < 4:
        return result
    ij, ijk, ijkl = combine(ts.samples())
    # Materialise every tree once rather than rebuilding both trees for each
    # of the T * (T + 1) / 2 pairs; the trees are only queried, never mutated.
    trees = [ts.at_index(index) for index in range(num_trees)]
    for row, col in combinations_with_replacement(range(num_trees), 2):
        val = stat_func(trees[row], trees[col], ij, ijk, ijkl)
        # Fill the upper triangle and mirror it into the lower triangle.
        result[row, col] = val
        result[col, row] = val
    return result


@pytest.mark.parametrize(
    "ts",
    [
        ts
        for ts in get_example_tree_sequences()
        # no_samples and empty_ts aren't handled here.
        if ts.id
        not in {
            "no_samples",
            "empty_ts",
            # These skipped tests are too slow for our current naive prototype
            "bottleneck_n=100_mutated",
            "n=100_m=32_rho=0.1",
            "n=100_m=32_rho=0.5",
            # These ones fail for our naive prototype because
            # some samples do not have a mrca
            "all_fields",
            "decapitate",
            "decapitate_recomb",
            "empty_tree",
            "gap_0",
            "gap_0.1",
            "gap_0.5",
            "gap_0.75",
            "gap_at_end",
        }
    ],
)
@pytest.mark.parametrize(
    "stat,stat_func",
    zip(
        ["d2_unbiased", "dz_unbiased", "pi2_unbiased"],
        [compute_D2, compute_Dz, compute_pi2],
    ),
)
def test_branch_ld_matrix(ts, stat, stat_func):
    # The haplotype-count branch statistic must agree with the naive
    # TMRCA-covariance implementation for every example tree sequence.
    np.testing.assert_array_almost_equal(
        ld_matrix(ts, stat=stat, mode="branch"), naive_matrix(ts, stat_func)
    )