diff --git a/python/tests/test_haplotype_matching_fb.py b/python/tests/test_haplotype_matching_fb.py index b727d62250..b09ebcc005 100644 --- a/python/tests/test_haplotype_matching_fb.py +++ b/python/tests/test_haplotype_matching_fb.py @@ -320,7 +320,7 @@ def update_probabilities(self, site, haplotype_state): match = ( haplotype_state == MISSING or haplotype_state == allelic_state[v] ) - st.value = self.compute_next_probability(site.id, st.value, match) + st.value = self.compute_next_probability(site.id, st.value, match, u) # Unset the states allelic_state[tree.root] = -1 @@ -329,7 +329,7 @@ def update_probabilities(self, site, haplotype_state): def process_site(self, site, haplotype_state, forwards=True): if forwards: - # Forwards algorithm + # Forwards algorithm, or forwards pass in Viterbi self.update_probabilities(site, haplotype_state) self.compress() s = self.compute_normalisation_factor() @@ -385,7 +385,7 @@ def run_backward(self, h): def compute_normalisation_factor(self): raise NotImplementedError() - def compute_next_probability(self, site_id, p_last, is_match): + def compute_next_probability(self, site_id, p_last, is_match, node): raise NotImplementedError() @@ -447,6 +447,86 @@ class BackwardMatrix(CompressedMatrix): """Class representing a compressed backward matrix.""" +class ViterbiMatrix(CompressedMatrix): + """ + Class representing the compressed Viterbi matrix. 
+ """ + + def __init__(self, ts): + super().__init__(ts) + # Tuple containing the site, the node in the tree, and whether + # recombination is required + self.recombination_required = [(-1, 0, False)] + + def add_recombination_required(self, site, node, required): + self.recombination_required.append((site, node, required)) + + def choose_sample(self, site_id, tree): + max_value = -1 + u = -1 + for node, value in self.value_transitions[site_id]: + if value > max_value: + max_value = value + u = node + assert u != -1 + + transition_nodes = [u for (u, _) in self.value_transitions[site_id]] + while not tree.is_sample(u): + for v in tree.children(u): + if v not in transition_nodes: + u = v + break + else: + raise AssertionError("could not find path") + return u + + def traceback(self): + # Run the traceback. + m = self.ts.num_sites + match = np.zeros(m, dtype=int) + recombination_tree = np.zeros(self.ts.num_nodes, dtype=int) - 1 + tree = tskit.Tree(self.ts) + tree.last() + current_node = -1 + + rr_index = len(self.recombination_required) - 1 + for site in reversed(self.ts.sites()): + while tree.interval.left > site.position: + tree.prev() + assert tree.interval.left <= site.position < tree.interval.right + + # Fill in the recombination tree + j = rr_index + while self.recombination_required[j][0] == site.id: + u, required = self.recombination_required[j][1:] + recombination_tree[u] = required + j -= 1 + + if current_node == -1: + current_node = self.choose_sample(site.id, tree) + match[site.id] = current_node + + # Now traverse up the tree from the current node. The first marked node + # we meet tells us whether we need to recombine. + u = current_node + while u != -1 and recombination_tree[u] == -1: + u = tree.parent(u) + + assert u != -1 + if recombination_tree[u] == 1: + # Need to switch at the next site. + current_node = -1 + # Reset the nodes in the recombination tree. 
+ j = rr_index + while self.recombination_required[j][0] == site.id: + u, required = self.recombination_required[j][1:] + recombination_tree[u] = -1 + j -= 1 + rr_index = j + + return match + + class ForwardAlgorithm(LsHmmAlgorithm): """Runs the Li and Stephens forward algorithm.""" @@ -472,7 +552,9 @@ def compute_normalisation_factor(self): s += self.N[j] * st.value return s - def compute_next_probability(self, site_id, p_last, is_match): + def compute_next_probability( + self, site_id, p_last, is_match, node + ): # Note node only used in Viterbi rho = self.rho[site_id] mu = self.mu[site_id] n = self.ts.num_samples @@ -541,7 +623,9 @@ def compute_normalisation_factor(self): s += self.N[j] * st.value return s - def compute_next_probability(self, site_id, p_next, is_match): + def compute_next_probability( + self, site_id, p_next, is_match, node + ): # Note node only used in Viterbi mu = self.mu[site_id] n_alleles = self.n_alleles[site_id] @@ -564,6 +648,82 @@ def compute_next_probability(self, site_id, p_next, is_match): return p_next * p_e +class ViterbiAlgorithm(LsHmmAlgorithm): + """ + Runs the Li and Stephens Viterbi algorithm. 
+ """ + + def __init__( + self, ts, rho, mu, alleles, n_alleles, scale_mutation=False, precision=10 + ): + super().__init__( + ts, + rho, + mu, + alleles, + n_alleles, + precision=precision, + scale_mutation=scale_mutation, + ) + self.output = ViterbiMatrix(ts) + + def compute_normalisation_factor(self): + max_st = ValueTransition(value=-1) + for st in self.T: + assert st.tree_node != tskit.NULL + if st.value > max_st.value: + max_st = st + if max_st.value == 0: + raise ValueError( + "Trying to match non-existent allele with zero mutation rate" + ) + return max_st.value + + def compute_next_probability(self, site_id, p_last, is_match, node): + rho = self.rho[site_id] + mu = self.mu[site_id] + n = self.ts.num_samples + n_alleles = self.n_alleles[site_id] + + p_no_recomb = p_last * (1 - rho + rho / n) + p_recomb = rho / n + recombination_required = False + if p_no_recomb > p_recomb: + p_t = p_no_recomb + else: + p_t = p_recomb + recombination_required = True + self.output.add_recombination_required(site_id, node, recombination_required) + + if self.scale_mutation_based_on_n_alleles: + if is_match: + # Scale mutation based on the number of alleles + # - so the mutation rate is the mutation rate to one of the + # alleles. The overall mutation rate is then + # (n_alleles - 1) * mutation_rate. + p_e = 1 - (n_alleles - 1) * mu + else: + p_e = mu - mu * (n_alleles == 1) + # Added boolean in case we're at an invariant site + else: + # No scaling based on the number of alleles + # - so the mutation rate is the mutation rate to anything. + # This means that we must rescale the mutation rate to a different + # allele, by the number of alleles. 
+ if n_alleles == 1: # In case we're at an invariant site + if is_match: + p_e = 1 + else: + p_e = 0 + else: + if is_match: + p_e = 1 - mu + else: + p_e = mu / (n_alleles - 1) + + return p_t * p_e + + def ls_forward_tree( h, ts, rho, mu, precision=30, alleles=None, scale_mutation_based_on_n_alleles=False ): @@ -636,6 +796,43 @@ def ls_backward_tree( return ba.run_backward(h) +def ls_viterbi_tree( + h, ts, rho, mu, precision=30, alleles=None, scale_mutation_based_on_n_alleles=False +): + if alleles is None: + n_alleles = np.int8( + [ + len(np.unique(np.append(ts.genotype_matrix()[j, :], h[j]))) + for j in range(ts.num_sites) + ] + ) + alleles = tskit.ALLELES_ACGT + if len(set(alleles).intersection(next(ts.variants()).alleles)) == 0: + alleles = tskit.ALLELES_01 + if len(set(alleles).intersection(next(ts.variants()).alleles)) == 0: + raise ValueError( + """Alleles list could not be identified. + Please pass a list of lists of alleles of length m, + or a list of alleles (e.g. tskit.ALLELES_ACGT)""" + ) + alleles = [alleles for _ in range(ts.num_sites)] + else: + alleles, n_alleles = check_alleles(alleles, ts.num_sites) + """ + Viterbi path computation based on a tree sequence. 
+ """ + va = ViterbiAlgorithm( + ts, + rho, + mu, + alleles, + n_alleles, + precision=precision, + scale_mutation=scale_mutation_based_on_n_alleles, + ) + return va.run_forward(h) + + class LSBase: """Superclass of Li and Stephens tests.""" @@ -788,6 +985,10 @@ class FBAlgorithmBase(LSBase): """Base for forwards backwards algorithm tests.""" +class VitAlgorithmBase(LSBase): + """Base for Viterbi algorithm tests.""" + + class TestMirroringHap(FBAlgorithmBase): """Tests that mirroring the tree sequence and running forwards and backwards algorithms gives the same log-likelihood of observing the data.""" @@ -889,3 +1090,31 @@ def verify(self, ts): self.assertAllClose(B, B_tree) self.assertAllClose(F, F_tree) self.assertAllClose(ll, ll_tree) + + +class TestTreeViterbiHap(VitAlgorithmBase): + """Test that we have the same log-likelihood between tree and matrix + implementations""" + + def verify(self, ts): + for n, H, s, r, mu in self.example_parameters_haplotypes(ts): + path, ll = ls.viterbi( + H, s, r, mutation_rate=mu, scale_mutation_based_on_n_alleles=False + ) + ts_check = ts.simplify(range(1, n + 1), filter_sites=False) + cm = ls_viterbi_tree(s[0, :], ts_check, r, mu) + ll_tree = np.sum(np.log10(cm.normalisation_factor)) + self.assertAllClose(ll, ll_tree) + + # Now, need to ensure that the likelihood of the preferred path is + # the same as ll_tree (and ll). + path_tree = cm.traceback() + ll_check = ls.path_ll( + H, + s, + path_tree, + r, + mutation_rate=mu, + scale_mutation_based_on_n_alleles=False, + ) + self.assertAllClose(ll, ll_check) diff --git a/python/tests/test_haplotype_matching_viterbi.py b/python/tests/test_haplotype_matching_viterbi.py index 571b4bb337..004e01c76e 100644 --- a/python/tests/test_haplotype_matching_viterbi.py +++ b/python/tests/test_haplotype_matching_viterbi.py @@ -20,7 +20,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. """ -Python implementation of the Li and Stephens Viterbi algorithm. 
+Python implementation of the Li and Stephens Viterbi algorithms. """ import itertools import unittest