From 5c6c59fb26a6d1f8c9284b5bec376b32816e9943 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 14 Jul 2023 15:00:25 +0100 Subject: [PATCH] Remove the mirrored version of backwards alg --- python/tests/test_haplotype_matching.py | 368 +++++++++--------------- 1 file changed, 133 insertions(+), 235 deletions(-) diff --git a/python/tests/test_haplotype_matching.py b/python/tests/test_haplotype_matching.py index 353fc5db93..07157be6bd 100644 --- a/python/tests/test_haplotype_matching.py +++ b/python/tests/test_haplotype_matching.py @@ -406,136 +406,6 @@ def compute_next_probability(self, site_id, p_last, is_match, node): raise NotImplementedError() -class CompressedMatrix: - """ - Class representing a num_samples x num_sites matrix compressed by a - tree sequence. Each site is represented by a set of (node, value) - pairs, which act as "mutations", i.e., any sample that descends - from a particular node will inherit that value (unless any other - values are on the path). - """ - - def __init__(self, ts, normalisation_factor=None): - self.ts = ts - self.num_sites = ts.num_sites - self.num_samples = ts.num_samples - self.value_transitions = [None for _ in range(self.num_sites)] - if normalisation_factor is None: - self.normalisation_factor = np.zeros(self.num_sites) - else: - self.normalisation_factor = normalisation_factor - assert len(self.normalisation_factor) == self.num_sites - - def store_site(self, site, normalisation_factor, value_transitions): - self.normalisation_factor[site] = normalisation_factor - self.value_transitions[site] = value_transitions - - # Expose the same API as the low-level classes - - @property - def num_transitions(self): - a = [len(self.value_transitions[j]) for j in range(self.num_sites)] - return np.array(a, dtype=np.int32) - - def get_site(self, site): - return self.value_transitions[site] - - def decode(self): - """ - Decodes the tree encoding of the values into an explicit - matrix. - """ - A = np.zeros((self.num_sites, self.num_samples)) - for tree in self.ts.trees(): - for site in tree.sites(): - f = dict(self.value_transitions[site.id]) - for j, u in enumerate(self.ts.samples()): - while u not in f: - u = tree.parent(u) - A[site.id, j] = f[u] - return A - - -class ViterbiMatrix(CompressedMatrix): - """ - Class representing the compressed Viterbi matrix. - """ - - def __init__(self, ts): - super().__init__(ts) - # Tuple containing the site, the node in the tree, and whether - # recombination is required - self.recombination_required = [(-1, 0, False)] - - def add_recombination_required(self, site, node, required): - self.recombination_required.append((site, node, required)) - - def choose_sample(self, site_id, tree): - max_value = -1 - u = -1 - for node, value in self.value_transitions[site_id]: - if value > max_value: - max_value = value - u = node - assert u != -1 - - transition_nodes = [u for (u, _) in self.value_transitions[site_id]] - while not tree.is_sample(u): - for v in tree.children(u): - if v not in transition_nodes: - u = v - break - else: - raise AssertionError("could not find path") - return u - - def traceback(self): - # Run the traceback. - m = self.ts.num_sites - match = np.zeros(m, dtype=int) - recombination_tree = np.zeros(self.ts.num_nodes, dtype=int) - 1 - tree = tskit.Tree(self.ts) - tree.last() - current_node = -1 - - rr_index = len(self.recombination_required) - 1 - for site in reversed(self.ts.sites()): - while tree.interval.left > site.position: - tree.prev() - assert tree.interval.left <= site.position < tree.interval.right - - # Fill in the recombination tree - j = rr_index - while self.recombination_required[j][0] == site.id: - u, required = self.recombination_required[j][1:] - recombination_tree[u] = required - j -= 1 - - if current_node == -1: - current_node = self.choose_sample(site.id, tree) - match[site.id] = current_node - - # Now traverse up the tree from the current node. The first marked node - # we meet tells us whether we need to recombine. - u = current_node - while u != -1 and recombination_tree[u] == -1: - u = tree.parent(u) - - assert u != -1 - if recombination_tree[u] == 1: - # Need to switch at the next site. - current_node = -1 - # Reset the nodes in the recombination tree. - j = rr_index - while self.recombination_required[j][0] == site.id: - u, required = self.recombination_required[j][1:] - recombination_tree[u] = -1 - j -= 1 - rr_index = j - - return match - - class ForwardAlgorithm(LsHmmAlgorithm): """ The Li and Stephens forward algorithm. @@ -629,70 +499,6 @@ def run(self, h): return self.output -class MirroredBackwardAlgorithm(LsHmmAlgorithm): - """Runs the Li and Stephens backward algorithm.""" - - def __init__( - self, - ts, - rho, - mu, - alleles, - n_alleles, - normalisation_factor, - scale_mutation=False, - precision=10, - ): - super().__init__( - ts, - rho, - mu, - alleles, - n_alleles, - precision=precision, - scale_mutation=scale_mutation, - ) - self.output = CompressedMatrix(ts, normalisation_factor) - - def compute_normalisation_factor(self): - s = 0 - for j, st in enumerate(self.T): - assert st.tree_node != tskit.NULL - assert self.N[j] > 0 - s += self.N[j] * st.value - return s - - def compute_next_probability(self, site_id, p_next, is_match, node): - p_e = self.compute_emission_proba(site_id, is_match) - return p_next * p_e - - def process_site(self, site, haplotype_state): - self.output.store_site( - site.id, - self.output.normalisation_factor[site.id], - [(st.tree_node, st.value) for st in self.T], - ) - self.update_probabilities(site, haplotype_state) - self.compress() - b_last_sum = self.compute_normalisation_factor() - s = self.output.normalisation_factor[site.id] - for st in self.T: - if st.tree_node != tskit.NULL: - st.value = (self.rho[site.id] / self.ts.num_samples) * b_last_sum + ( - 1 - self.rho[site.id] - ) * st.value - st.value /= s - st.value = round(st.value, self.precision) - - def run(self, h): - self.initialise(value=1) - while self.tree.next(): - self.update_tree() - for site in self.tree.sites(): - self.process_site(site, h[site.id]) - return self.output - - class ViterbiAlgorithm(LsHmmAlgorithm): """ Runs the Li and Stephens Viterbi algorithm. @@ -742,6 +548,136 @@ def compute_next_probability(self, site_id, p_last, is_match, node): return p_t * p_e +class CompressedMatrix: + """ + Class representing a num_samples x num_sites matrix compressed by a + tree sequence. Each site is represented by a set of (node, value) + pairs, which act as "mutations", i.e., any sample that descends + from a particular node will inherit that value (unless any other + values are on the path). + """ + + def __init__(self, ts, normalisation_factor=None): + self.ts = ts + self.num_sites = ts.num_sites + self.num_samples = ts.num_samples + self.value_transitions = [None for _ in range(self.num_sites)] + if normalisation_factor is None: + self.normalisation_factor = np.zeros(self.num_sites) + else: + self.normalisation_factor = normalisation_factor + assert len(self.normalisation_factor) == self.num_sites + + def store_site(self, site, normalisation_factor, value_transitions): + self.normalisation_factor[site] = normalisation_factor + self.value_transitions[site] = value_transitions + + # Expose the same API as the low-level classes + + @property + def num_transitions(self): + a = [len(self.value_transitions[j]) for j in range(self.num_sites)] + return np.array(a, dtype=np.int32) + + def get_site(self, site): + return self.value_transitions[site] + + def decode(self): + """ + Decodes the tree encoding of the values into an explicit + matrix. + """ + A = np.zeros((self.num_sites, self.num_samples)) + for tree in self.ts.trees(): + for site in tree.sites(): + f = dict(self.value_transitions[site.id]) + for j, u in enumerate(self.ts.samples()): + while u not in f: + u = tree.parent(u) + A[site.id, j] = f[u] + return A + + +class ViterbiMatrix(CompressedMatrix): + """ + Class representing the compressed Viterbi matrix. + """ + + def __init__(self, ts): + super().__init__(ts) + # Tuple containing the site, the node in the tree, and whether + # recombination is required + self.recombination_required = [(-1, 0, False)] + + def add_recombination_required(self, site, node, required): + self.recombination_required.append((site, node, required)) + + def choose_sample(self, site_id, tree): + max_value = -1 + u = -1 + for node, value in self.value_transitions[site_id]: + if value > max_value: + max_value = value + u = node + assert u != -1 + + transition_nodes = [u for (u, _) in self.value_transitions[site_id]] + while not tree.is_sample(u): + for v in tree.children(u): + if v not in transition_nodes: + u = v + break + else: + raise AssertionError("could not find path") + return u + + def traceback(self): + # Run the traceback. + m = self.ts.num_sites + match = np.zeros(m, dtype=int) + recombination_tree = np.zeros(self.ts.num_nodes, dtype=int) - 1 + tree = tskit.Tree(self.ts) + tree.last() + current_node = -1 + + rr_index = len(self.recombination_required) - 1 + for site in reversed(self.ts.sites()): + while tree.interval.left > site.position: + tree.prev() + assert tree.interval.left <= site.position < tree.interval.right + + # Fill in the recombination tree + j = rr_index + while self.recombination_required[j][0] == site.id: + u, required = self.recombination_required[j][1:] + recombination_tree[u] = required + j -= 1 + + if current_node == -1: + current_node = self.choose_sample(site.id, tree) + match[site.id] = current_node + + # Now traverse up the tree from the current node. The first marked node + # we meet tells us whether we need to recombine. + u = current_node + while u != -1 and recombination_tree[u] == -1: + u = tree.parent(u) + + assert u != -1 + if recombination_tree[u] == 1: + # Need to switch at the next site. + current_node = -1 + # Reset the nodes in the recombination tree. + j = rr_index + while self.recombination_required[j][0] == site.id: + u, required = self.recombination_required[j][1:] + recombination_tree[u] = -1 + j -= 1 + rr_index = j + + return match + + def get_site_alleles(ts, h, alleles): if alleles is None: n_alleles = np.int8( @@ -781,22 +717,6 @@ def ls_forward_tree( return fa.run(h) -def ls_backward_tree_mirrored( - h, ts_mirror, rho, mu, normalisation_factor, precision=30, alleles=None -): - alleles, n_alleles = get_site_alleles(ts_mirror, h, alleles) - ba = MirroredBackwardAlgorithm( - ts_mirror, - rho, - mu, - alleles, - n_alleles, - normalisation_factor, - precision=precision, - ) - return ba.run(h) - - def ls_backward_tree(h, ts, rho, mu, normalisation_factor, precision=30, alleles=None): alleles, n_alleles = get_site_alleles(ts, h, alleles) ba = BackwardAlgorithm( @@ -1076,16 +996,6 @@ def verify(self, ts): c_f = ls_forward_tree(s[0, :], ts_check, r, mu) ll_tree = np.sum(np.log10(c_f.normalisation_factor)) - ts_check_mirror = mirror_coordinates(ts_check) - r_flip = np.flip(r) - c_b = ls_backward_tree_mirrored( - np.flip(s[0, :]), - ts_check_mirror, - r_flip, - np.flip(mu), - np.flip(c_f.normalisation_factor), - ) - B_tree = np.flip(c_b.decode(), axis=0) c_b = ls_backward_tree( s[0, :], ts_check, @@ -1093,12 +1003,11 @@ def verify(self, ts): mu, c_f.normalisation_factor, ) - B_tree2 = c_b.decode() + B_tree = c_b.decode() F_tree = c_f.decode() self.assertAllClose(B, B_tree) - self.assertAllClose(B, B_tree2) self.assertAllClose(F, F_tree) self.assertAllClose(ll, ll_tree) @@ -1249,25 +1158,14 @@ def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None): scale_mutation_based_on_n_alleles=False, ) - ts_mirror = mirror_coordinates(ts) - backward_cm = ls_backward_tree_mirrored( - np.flip(h), - ts_mirror, - np.flip(recombination), - np.flip(mutation), - np.flip(forward_cm.normalisation_factor), - ) - B_tree = np.flip(backward_cm.decode(), axis=0) - nt.assert_array_almost_equal(B, B_tree) - - backward_cm2 = ls_backward_tree( + backward_cm = ls_backward_tree( h, ts, recombination, mutation, forward_cm.normalisation_factor, ) - B_tree = backward_cm2.decode() + B_tree = backward_cm.decode() nt.assert_array_almost_equal(B, B_tree)