From 1761e0569dd93a9830c720e6f73d720923da04b3 Mon Sep 17 00:00:00 2001 From: lkirk Date: Thu, 21 Mar 2024 10:08:20 -0500 Subject: [PATCH] Adds a py prototype for two-locus branch stats Currently, this algorithm creates a matrix of LD, performing a pairwise comparison of all trees in the tree sequence. This implementation lacks windows/positions, sample sets and polarisation. The code produces results in units of branch length, which need to be multiplied by mu^2 or divided by the product of the total branch lengths of the two trees. This algorithm works by keeping a running sum of the statistic between two trees, updating it each time we encounter a branch addition or removal. The tricky part is that we have to remove or add LD contributed by samples that already existed or that will remain under a given node after the addition/removal of branches. We include a validation against the original formulation of this problem, by including an implementation of the method described in McVean 2002. The original formulation computes the covariance of tMRCAs of 2, 3, and 4 samples of individuals on the trees in question. The McVean implementation has two limitations: 1) it is very slow and 2) it does not work for trees that are decapitated, as certain samples do not have MRCAs. 
--- python/tests/test_ld_matrix.py | 529 +++++++++++++++++++++++++++++++-- 1 file changed, 509 insertions(+), 20 deletions(-) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index 58b52debb8..dbc3695492 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -24,8 +24,12 @@ """ import contextlib import io +from copy import deepcopy +from dataclasses import dataclass +from itertools import combinations from itertools import combinations_with_replacement from itertools import permutations +from itertools import product from typing import Any from typing import Callable from typing import Dict @@ -38,6 +42,7 @@ import pytest import tskit +from tests import tsutil from tests.test_highlevel import get_example_tree_sequences @@ -603,6 +608,99 @@ def two_site_count_stat( return result +def two_branch_count_stat( + ts: tskit.TreeSequence, + func: Callable[[int, np.ndarray, np.ndarray, Dict[str, Any]], None], + norm_func: Callable[ # No need for norm func, biallelic + [int, np.ndarray, int, int, np.ndarray, Dict[str, Any]], None + ], + num_sample_sets: int, + sample_set_sizes: np.ndarray, + sample_sets: BitSet, # TODO: implement sample sets + sample_index_map: np.ndarray, + row_sites: np.ndarray, # TODO: positions + col_sites: np.ndarray, # TODO: positions + polarised: bool, # TODO: polarisation +) -> np.ndarray: + """ + Compute a tree X tree LD matrix by walking along the tree sequence and + computing haplotype counts. This method incrementally adds and removes + branches from a tree sequence and updates the stat based on sample additions + and removals. We bifurcate the tree with a given branch on each locus and + intersect the samples under each branch to produce haplotype counts. We + output a full LD matrix for the entire tree sequence. + + :param ts: Tree sequence to gather data from. + :param func: Function used to compute each two-locus statistic. 
+ :param norm_func: Not (YET) applicable for branch stats: TODO? + :param num_sample_sets: Number of sample sets that we will consider. + :param sample_set_sizes: Number of samples in each sample set. + :param sample_sets: BitSet of samples to compute stats for. We will only + consider these samples in our computations, resulting + in stats that are computed on subsets of the samples + on the tree sequence. + :param row_sites: Rows of the output matrix; currently one per tree (TODO). + :param col_sites: Columns of the output matrix; currently one per tree (TODO). + :param polarised: If true, skip the computation of the statistic for the + ancestral state. + :returns: 3D array of results, dimensions (sample_sets, row_sites, col_sites). + """ + # state_dim = len(sample_set_sizes) + # params = {"sample_set_sizes": sample_set_sizes} + result = np.zeros( + (num_sample_sets, len(row_sites), len(col_sites)), dtype=np.float64 + ) + # These stats require at least 4 samples in the tree + if ts.num_samples < 4: + return result + + # TODO: get_pos_row_col_indices? + # sites, row_idx, col_idx = get_site_row_col_indices(row_sites, col_sites) + + stat = 0 + l_state = TreeState(ts) # State is initialized at tree -1 + r_state = TreeState(ts) + tmp = None # Tmp tree will be used for swapping below. + tmp_stat = None + + # Advance the fixed (left) tree state by adding all edges. Stat will be 0 + stat, l_state = compute_stat(ts, func, stat, l_state, TreeState(ts)) + for i in range(ts.num_trees): + # Compute the stat between the left and right state, advancing the right + # state to the SAME tree that the lefthand tree represents. If we start + # at tree 0 and tree -1, we end up at tree 0 and tree 0, etc. 
+ stat, r_state = compute_stat(ts, func, stat, l_state, r_state) + # TODO: sample sets + result[0, l_state.pos.index, r_state.pos.index] = stat + # Continue to advance the righthand tree until we hit the end of the + # tree sequence, computing the stat for the rest of this row of the + # upper triangle of the LD matrix. + for _ in range(i, ts.num_trees - 1): + stat, r_state = compute_stat(ts, func, stat, l_state, r_state) + # TODO: sample sets + result[0, l_state.pos.index, r_state.pos.index] = stat + # After the first iteration of this inner loop, we store the r_state + # and stat to be used as the lefthand tree in the next outer iteration. + if tmp is None: + tmp_stat = stat + tmp = deepcopy(r_state) + # We store the focal tree as the starting point for the next row in the + # LD matrix. Remember that in order to compute the association between + # tree 1 and tree 1, we need the righthand state to *start* at tree 0. + r_state = deepcopy(l_state) + if tmp is not None: + l_state = deepcopy(tmp) + stat = tmp_stat + tmp = None + + # Reflect the upper triangle of the LD matrix to the lower triangle. 
+ tril_idx = np.tril_indices(len(result[0]), k=-1) + # TODO: sample sets + tril_idx = (np.zeros(len(tril_idx[0]), dtype=int), *tril_idx) + result[tril_idx] = result[tril_idx[0], tril_idx[2], tril_idx[1]] + return result + + def sample_sets_to_bit_array( ts: tskit.TreeSequence, sample_sets: List[List[int]] ) -> Tuple[np.ndarray, np.ndarray, BitSet]: @@ -648,6 +746,7 @@ def two_locus_count_stat( summary_func, norm_func, polarised, + mode, sites=None, sample_sets=None, ): @@ -692,18 +791,34 @@ def two_locus_count_stat( sample_index_map, ss_sizes, ss_bits = sample_sets_to_bit_array(ts, sample_sets) - result = two_site_count_stat( - ts, - summary_func, - norm_func, - len(ss_sizes), - ss_sizes, - ss_bits, - sample_index_map, - row_sites, - col_sites, - polarised, - ) + if mode == "site": + result = two_site_count_stat( + ts, + summary_func, + norm_func, + len(ss_sizes), + ss_sizes, + ss_bits, + sample_index_map, + row_sites, + col_sites, + polarised, + ) + elif mode == "branch": + result = two_branch_count_stat( + ts, + summary_func, + None, + 1, + [ts.num_samples], + None, + None, + range(ts.num_trees), + range(ts.num_trees), + False, + ) + else: + raise ValueError(f"Unknown mode: {mode}") # If there is one sample set, return a 2d numpy array of row/site LD if len(sample_sets) == 1: @@ -848,6 +963,48 @@ def pi2_summary_func( result[k] = p_A * (1 - p_A) * p_B * (1 - p_B) +# Unbiased estimators of pi2, dz, and d2. These are derived in Ragsdale 2019 +# (https://doi.org/10.1093/molbev/msz265) and can be used in place of the method +# outlined by McVean 2002. The reason for using haplotype counts in the branch +# methods is that we can compute statistics that cannot be represented by tMRCA +# covariance. With these unbiased estimators, we still reproduce the values +# estimated with tMRCA covariance. + +# TODO: update these summary functions to have the same function signature as +# the summary functions defined above. 
+ + +def pi2_unbiased(w_AB, w_Ab, w_aB, n): + w_ab = n - (w_AB + w_Ab + w_aB) # samples carrying neither A nor B + return (1 / (n * (n - 1) * (n - 2) * (n - 3))) * ( # requires n >= 4 + ((w_AB + w_Ab) * (w_aB + w_ab) * (w_AB + w_aB) * (w_Ab + w_ab)) + - ((w_AB * w_ab) * (w_AB + w_ab + (3 * w_Ab) + (3 * w_aB) - 1)) + - ((w_Ab * w_aB) * (w_Ab + w_aB + (3 * w_AB) + (3 * w_ab) - 1)) + ) + + +def dz_unbiased(w_AB, w_Ab, w_aB, n): + w_ab = n - (w_AB + w_Ab + w_aB) # samples carrying neither A nor B + return (1 / (n * (n - 1) * (n - 2) * (n - 3))) * ( # requires n >= 4 + ( + ((w_AB * w_ab) - (w_Ab * w_aB)) + * (w_aB + w_ab - w_AB - w_Ab) + * (w_Ab + w_ab - w_AB - w_aB) + ) + - ((w_AB * w_ab) * (w_AB + w_ab - w_Ab - w_aB - 2)) + - ((w_Ab * w_aB) * (w_Ab + w_aB - w_AB - w_ab - 2)) + ) + + +def d2_unbiased(w_AB, w_Ab, w_aB, n): + w_ab = n - (w_AB + w_Ab + w_aB) # samples carrying neither A nor B + return (1 / (n * (n - 1) * (n - 2) * (n - 3))) * ( # requires n >= 4 + ((w_aB**2) * (w_Ab - 1) * w_Ab) + + ((w_ab - 1) * w_ab * (w_AB - 1) * w_AB) + - (w_aB * w_Ab * (w_Ab + (2 * w_ab * w_AB) - 1)) + ) + + SUMMARY_FUNCS = { "r": r_summary_func, "r2": r2_summary_func, @@ -856,6 +1013,9 @@ def pi2_summary_func( "D_prime": D_prime_summary_func, "pi2": pi2_summary_func, "Dz": Dz_summary_func, + "d2_unbiased": d2_unbiased, + "dz_unbiased": dz_unbiased, + "pi2_unbiased": pi2_unbiased, } NORM_METHOD = { @@ -866,6 +1026,9 @@ def pi2_summary_func( pi2_summary_func: norm_total_weighted, r_summary_func: norm_total_weighted, r2_summary_func: norm_hap_weighted, + d2_unbiased: None, + dz_unbiased: None, + pi2_unbiased: None, } POLARIZATION = { @@ -876,21 +1039,20 @@ def pi2_summary_func( pi2_summary_func: False, r_summary_func: True, r2_summary_func: False, + d2_unbiased: False, + dz_unbiased: False, + pi2_unbiased: False, } -def ld_matrix( - ts, - sample_sets=None, - sites=None, - stat="r2", -): +def ld_matrix(ts, sample_sets=None, sites=None, stat="r2", mode="site"): summary_func = SUMMARY_FUNCS[stat] return two_locus_count_stat( ts, summary_func, NORM_METHOD[summary_func], POLARIZATION[summary_func], + mode, sites=sites, sample_sets=sample_sets, ) 
@@ -1049,7 +1211,10 @@ def test_compare_to_ld_calculator(): np.testing.assert_array_almost_equal(ld_calc.get_r2_matrix(), ts.ld_matrix()) -@pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys()) +@pytest.mark.parametrize( + "stat", + sorted(SUMMARY_FUNCS.keys() - {"d2_unbiased", "dz_unbiased", "pi2_unbiased"}), +) def test_multiallelic_with_back_mutation(stat): ts = msprime.sim_ancestry( samples=4, recombination_rate=0.2, sequence_length=10, random_seed=1 @@ -1066,7 +1231,11 @@ def test_multiallelic_with_back_mutation(stat): if ts.id not in {"no_samples", "empty_ts"} ], ) -@pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys()) +# TODO: port unbiased summary functions +@pytest.mark.parametrize( + "stat", + sorted(SUMMARY_FUNCS.keys() - {"d2_unbiased", "dz_unbiased", "pi2_unbiased"}), +) def test_ld_matrix(ts, stat): np.testing.assert_array_equal(ld_matrix(ts, stat=stat), ts.ld_matrix(stat=stat)) @@ -1092,3 +1261,323 @@ def test_input_validation(): ts.ld_matrix(sites=[[1, 2], [2, 3], [3, 4]]) with pytest.raises(ValueError, match="must be a length 1 or 2 list"): ts.ld_matrix(sites=[]) + + +@dataclass +class TreeState: + """ + Class for storing tree state from one iteration to the next. This object + enables easy copying of the state for computing a matrix. + """ + + pos: tsutil.TreePosition # current position in the tree sequence + parent: np.ndarray # parent node of a given node (connected by an edge) + branch_len: np.ndarray # length of the branch above a particular child node + node_samples: BitSet # samples that exist under a given node; this is a + # bitset with a row for each node. 
+ + def __init__(self, ts): + self.pos = tsutil.TreePosition(ts) + self.parent = -np.ones(ts.num_nodes, dtype=np.int64) + self.branch_len = np.zeros(ts.num_nodes, dtype=np.float64) + self.node_samples = BitSet(ts.num_samples, ts.num_nodes) + for s in ts.samples(): + self.node_samples.add(s, s) + + +def compute_stat_update(c, child_samples, A_state, B_state, stat_func, num_samples): + """Compute an update to the two-locus statistic for a single subset of the + tree being modified, relative to all subsets of the fixed tree. We perform + this operation for the child of the edge being modified. For subsequent parent + nodes, we update the statistic by removing the existing contribution after + adding in the updated contribution. + + i.e. if we're adding two samples ({3, 4}) to a node, if the parent node + contains {1, 2}, we first add the statistic for {1, 2, 3, 4}, then + subtract the stat for {1, 2}. + + :param c: Node in the modified tree (the edge's child or one of its ancestors) + :param child_samples: Samples under the modified edge; None at the edge's child + :param A_state: State for the tree contributing to the A samples (fixed) + :param B_state: State for the tree contributing to the B samples (modified) + :param stat_func: Function used to compute the two-locus statistic + :param num_samples: Number of samples in the tree sequence + :returns: The change to the statistic, given a single edge update in the tree + """ + stat = 0 + b_len = B_state.branch_len[c] + if b_len == 0: + return stat + AB_samples = BitSet(num_samples, 1) + node_samples_tmp = BitSet(num_samples, 1) + + for n in np.where(A_state.branch_len > 0)[0]: + a_len = A_state.branch_len[n] + # Samples under the modified edge and the current fixed tree node are AB + A_state.node_samples.intersect(n, B_state.node_samples, c, AB_samples) + w_AB = AB_samples.count(0) + w_A = A_state.node_samples.count(n) + w_Ab = w_A - w_AB + w_aB = B_state.node_samples.count(c) - w_AB + stat += stat_func(w_AB, w_Ab, w_aB, num_samples) * a_len * b_len + + # If we've 
begun our walk up the parents of the modified edge, we + # must adjust the statistic for samples that were already present before + # addition or that remain after removal. + if child_samples is not None: + node_samples_tmp.union(0, B_state.node_samples, c) + node_samples_tmp.difference(0, child_samples, 0) + AB_samples.data[:] = 0 + # Zero out the bitset so that we can reuse it + A_state.node_samples.intersect(n, node_samples_tmp, 0, AB_samples) + w_AB = AB_samples.count(0) + w_Ab = w_A - w_AB + w_aB = node_samples_tmp.count(0) - w_AB + stat -= stat_func(w_AB, w_Ab, w_aB, num_samples) * a_len * b_len + return stat + + +def compute_stat(ts, stat_func, stat, l_state, r_state): + """Step between trees in a tree sequence, updating our two-locus statistic + as we add or remove edges. Since we're computing statistics for two loci, we + have a focal tree that remains constant, and a tree that is updated to + represent the tree we're comparing to. The lefthand tree is held constant + and the righthand tree is modified. The statistic is updated as we add and + remove branches, and when we reach the point where the righthand tree is + fully updated, the statistic will have been updated to the two-locus + statistic between both trees. + + For instance, if we pass in the l_state for tree 0 and the r_state for tree + 0, we will update the r_state until r_state contains the information for + tree 1. Then, the statistic will represent the LD between tree 0 and tree 1. + + Currently, iteration happens in the forward direction. + + :param ts: The underlying tree sequence object that we're iterating across. + :param stat_func: A function that computes the two locus statistic, given + haplotype counts. + :param stat: The two-locus statistic computed between two trees. 
+ :param l_state: The lefthand, constant state + :param r_state: The righthand, mutated state + :returns: A tuple containing the statistic between the two trees after + branch updates and the righthand tree state. + """ + time = ts.tables.nodes.time + r_pos = r_state.pos + # Iterate the right tree position to the next tree. The rest of the r_state + # data is not valid until the end of this function. + assert r_pos.next(), "out of bounds" # NOTE(review): side effect inside assert + + child_samples = BitSet(ts.num_samples, 1) + for e in r_pos.out_range.order[r_pos.out_range.start : r_pos.out_range.stop]: + p = r_pos.ts.edges_parent[e] + c = r_pos.ts.edges_child[e] + child_samples.data[:] = 0 + child_samples.union(0, r_state.node_samples, c) # samples removed by this edge + + # Remove the LD contributed by the samples under removed edges. When + # we walk up the tree to propagate these changes to parents of the + # removed edge, we need to add back in the LD contributed by samples + # that aren't removed. We remove samples from the parents of the removed + # branch as we propagate changes upward + in_parent = None + while p != tskit.NULL: + stat -= compute_stat_update( + c, in_parent, l_state, r_state, stat_func, ts.num_samples + ) + if in_parent is not None: + # remove samples from the parents of the branch being removed; + # the edge's child itself (first iteration) keeps its samples + r_state.node_samples.difference(c, child_samples, 0) + in_parent = child_samples + c = p + p = r_state.parent[p] + r_state.node_samples.difference(c, child_samples, 0) + + # reset to the child of the edge being removed. 
+ c = ts.edges_child[e] + r_state.branch_len[c] = 0 + r_state.parent[c] = tskit.NULL + + for e in r_pos.in_range.order[r_pos.in_range.start : r_pos.in_range.stop]: + p = r_pos.ts.edges_parent[e] + c = r_pos.ts.edges_child[e] + child_samples.data[:] = 0 + child_samples.union(0, r_state.node_samples, c) # samples added by this edge + r_state.branch_len[c] = time[p] - time[c] + r_state.parent[c] = p + + # Add the LD contributed by the samples under added edges. When we walk + # up the tree to propagate these changes to parents of the added edge, + # we need to remove the LD contributed by samples that were already + # there + in_parent = None + while p != tskit.NULL: + r_state.node_samples.union(p, child_samples, 0) + stat += compute_stat_update( + c, in_parent, l_state, r_state, stat_func, ts.num_samples + ) + in_parent = child_samples + c = p + p = r_state.parent[p] + + return stat, r_state + + +# What follows is an implementation of two-locus statistics as described in +# McVean 2002 (https://doi.org/10.1093/genetics/162.2.987). We compute the +# covariance between coalescent times to produce expectations of coalescent +# times between three sampling patterns of samples. These expectations can be +# combined to produce D2, Dz, and pi2. These are for testing and to demonstrate +# conceptual parity between our method and McVean's method. 
+ + +def compute_D2(x, y, ij, ijk, ijkl): + E_ijij = 0 + E_ijik = 0 + E_ijkl = 0 + for i, j in ij: + i_time = x.time(i) + j_time = x.time(j) + avg_time = (i_time + j_time) / 2 + E_ijij += (x.tmrca(i, j) - avg_time) * (y.tmrca(i, j) - avg_time) + for i, j, k in ijk: + i_time = x.time(i) + j_time = x.time(j) + k_time = x.time(k) + ij_time = (i_time + j_time) / 2 + ik_time = (i_time + k_time) / 2 + E_ijik += (x.tmrca(i, j) - ij_time) * (y.tmrca(i, k) - ik_time) + for i, j, k, l in ijkl: + i_time = x.time(i) + j_time = x.time(j) + k_time = x.time(k) + l_time = x.time(l) + ij_time = (i_time + j_time) / 2 + kl_time = (k_time + l_time) / 2 + E_ijkl += (x.tmrca(i, j) - ij_time) * (y.tmrca(k, l) - kl_time) + E_ijij = E_ijij / len(ij) + E_ijik = E_ijik / len(ijk) + E_ijkl = E_ijkl / len(ijkl) + return E_ijij - 2 * E_ijik + E_ijkl + + +def compute_Dz(x, y, ij, ijk, ijkl): + E_ijik = 0 + E_ijkl = 0 + for i, j, k in ijk: + i_time = x.time(i) + j_time = x.time(j) + k_time = x.time(k) + ij_time = (i_time + j_time) / 2 + ik_time = (i_time + k_time) / 2 + E_ijik += (x.tmrca(i, j) - ij_time) * (y.tmrca(i, k) - ik_time) + for i, j, k, l in ijkl: + i_time = x.time(i) + j_time = x.time(j) + k_time = x.time(k) + l_time = x.time(l) + ij_time = (i_time + j_time) / 2 + kl_time = (k_time + l_time) / 2 + E_ijkl += (x.tmrca(i, j) - ij_time) * (y.tmrca(k, l) - kl_time) + E_ijik = E_ijik / len(ijk) + E_ijkl = E_ijkl / len(ijkl) + return 4 * (E_ijik - E_ijkl) + + +def compute_pi2(x, y, ij, ijk, ijkl): + E_ijkl = 0 + for i, j, k, l in ijkl: + i_time = x.time(i) + j_time = x.time(j) + k_time = x.time(k) + l_time = x.time(l) + ij_time = (i_time + j_time) / 2 + kl_time = (k_time + l_time) / 2 + E_ijkl += (x.tmrca(i, j) - ij_time) * (y.tmrca(k, l) - kl_time) + E_ijkl = E_ijkl / len(ijkl) + return E_ijkl + + +def combine(samples): + # All unordered pairs (i, j) where i != j + ij = list(combinations(samples, 2)) + # All ordered triples where i != {j,k} and j != k + ijk = [ + (i, j, k) + for i, j, k in 
product(samples, repeat=3) + if i != k and i != j and j != k + ] + # All pairs of pairs (i, j), (k, l) where i != {k,l} and j != {k,l} + ijkl = [ + (i, j, samples[k], samples[l]) + for i, j in combinations(samples, 2) + for k in range(len(samples)) + for l in range(k + 1, len(samples)) # noqa: E741 + if i != samples[k] and j != samples[k] and samples[l] != i and samples[l] != j + ] + return ij, ijk, ijkl + + +def naive_matrix(ts, stat_func): + """Compute a tree x tree LD matrix for a given tree sequence and two-locus + statistic. This produces a matrix of LD that is generated from the + covariance in gene genealogies, as described in McVean 2002. + + :param ts: Tree sequence to gather data from. + :param stat_func: Function to compute a two-locus statistic from two + materialized trees and sample combinations. + :returns: Pairwise branch LD matrix for an entire tree sequence. + """ + result = np.zeros((ts.num_trees, ts.num_trees), dtype=np.float64) + # These stats require at least 4 samples in the tree + if ts.num_samples < 4: + return result + ij, ijk, ijkl = combine(ts.samples()) + for i, j in combinations_with_replacement(range(ts.num_trees), 2): + val = stat_func(ts.at_index(i), ts.at_index(j), ij, ijk, ijkl) + result[i, j] = val + tri_idx = np.tril_indices(len(result), k=-1) # strictly-lower triangle + result[tri_idx] = result.T[tri_idx] # mirror the upper triangle + return result + + +@pytest.mark.parametrize( + "ts", + [ + ts + for ts in get_example_tree_sequences() + # no_samples and empty_ts aren't handled here. 
+ if ts.id + not in { + "no_samples", + "empty_ts", + # These examples are too slow for our current naive prototype + "bottleneck_n=100_mutated", + "n=100_m=32_rho=0.1", + "n=100_m=32_rho=0.5", + # These examples fail for our naive prototype because + # some samples do not have an MRCA + "all_fields", + "decapitate", + "decapitate_recomb", + "empty_tree", + "gap_0", + "gap_0.1", + "gap_0.5", + "gap_0.75", + "gap_at_end", + } + ], +) +@pytest.mark.parametrize( + "stat,stat_func", + zip( + ["d2_unbiased", "dz_unbiased", "pi2_unbiased"], + [compute_D2, compute_Dz, compute_pi2], + ), +) +def test_branch_ld_matrix(ts, stat, stat_func): + np.testing.assert_array_almost_equal( + ld_matrix(ts, stat=stat, mode="branch"), naive_matrix(ts, stat_func) + )