diff --git a/python/tests/test_haplotype_matching.py b/python/tests/test_haplotype_matching.py index bb98d5bfd3..353fc5db93 100644 --- a/python/tests/test_haplotype_matching.py +++ b/python/tests/test_haplotype_matching.py @@ -336,6 +336,7 @@ def update_probabilities(self, site, haplotype_state): match = ( haplotype_state == MISSING or haplotype_state == allelic_state[v] ) + # Note that the node u is used only by Viterbi st.value = self.compute_next_probability(site.id, st.value, match, u) # Unset the states @@ -383,12 +384,15 @@ def compute_emission_proba(self, site_id, is_match): p_e = mu / (n_alleles - 1) return p_e - def run(self, h): - n = self.ts.num_samples + def initialise(self, value): self.tree.clear() for u in self.ts.samples(): self.T_index[u] = len(self.T) - self.T.append(ValueTransition(tree_node=u, value=1 / n)) + self.T.append(ValueTransition(tree_node=u, value=value)) + + def run(self, h): + n = self.ts.num_samples + self.initialise(1 / n) while self.tree.next(): self.update_tree() for site in self.tree.sites(): @@ -533,7 +537,9 @@ def traceback(self): class ForwardAlgorithm(LsHmmAlgorithm): - """Runs the Li and Stephens forward algorithm.""" + """ + The Li and Stephens forward algorithm. + """ def __init__( self, ts, rho, mu, alleles, n_alleles, scale_mutation=False, precision=10 @@ -557,7 +563,6 @@ def compute_normalisation_factor(self): s += self.N[j] * st.value return s - # Note node only used in Viterbi def compute_next_probability(self, site_id, p_last, is_match, node): rho = self.rho[site_id] n = self.ts.num_samples @@ -566,8 +571,10 @@ def compute_next_probability(self, site_id, p_last, is_match, node): return p_t * p_e -class MirroredBackwardAlgorithm(LsHmmAlgorithm): - """Runs the Li and Stephens backward algorithm.""" +class BackwardAlgorithm(ForwardAlgorithm): + """ + The Li and Stephens backward algorithm. + """ def __init__( self, @@ -591,50 +598,38 @@ def __init__( ) self.output = CompressedMatrix(ts, normalisation_factor) - def compute_normalisation_factor(self): - s = 0 - for j, st in enumerate(self.T): - assert st.tree_node != tskit.NULL - assert self.N[j] > 0 - s += self.N[j] * st.value - return s - def compute_next_probability(self, site_id, p_next, is_match, node): p_e = self.compute_emission_proba(site_id, is_match) return p_next * p_e def process_site(self, site, haplotype_state): + s = self.output.normalisation_factor[site.id] self.output.store_site( site.id, - self.output.normalisation_factor[site.id], + s, [(st.tree_node, st.value) for st in self.T], ) self.update_probabilities(site, haplotype_state) self.compress() b_last_sum = self.compute_normalisation_factor() - s = self.output.normalisation_factor[site.id] + n = self.ts.num_samples + rho = self.rho[site.id] for st in self.T: if st.tree_node != tskit.NULL: - st.value = (self.rho[site.id] / self.ts.num_samples) * b_last_sum + ( - 1 - self.rho[site.id] - ) * st.value + st.value = rho * b_last_sum / n + (1 - rho) * st.value st.value /= s st.value = round(st.value, self.precision) def run(self, h): - # NOTE value=1 is the difference to the standard code for fwd and vit - self.tree.clear() - for u in self.ts.samples(): - self.T_index[u] = len(self.T) - self.T.append(ValueTransition(tree_node=u, value=1)) - while self.tree.next(): - self.update_tree() - for site in self.tree.sites(): + self.initialise(value=1) + while self.tree.prev(): + self.update_tree(direction=tskit.REVERSE) + for site in reversed(list(self.tree.sites())): self.process_site(site, h[site.id]) return self.output -class BackwardAlgorithm(LsHmmAlgorithm): +class MirroredBackwardAlgorithm(LsHmmAlgorithm): """Runs the Li and Stephens backward algorithm.""" def __init__( @@ -690,13 +685,10 @@ def process_site(self, site, haplotype_state): st.value = round(st.value, self.precision) def run(self, h): - self.tree.clear() - for u in self.ts.samples(): - self.T_index[u] = len(self.T) - self.T.append(ValueTransition(tree_node=u, value=1)) - while self.tree.prev(): - self.update_tree(direction=tskit.REVERSE) - for site in reversed(list(self.tree.sites())): + self.initialise(value=1) + while self.tree.next(): + self.update_tree() + for site in self.tree.sites(): self.process_site(site, h[site.id]) return self.output