Skip to content

Commit

Permalink
Reasonably well factored Backward alg
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Jul 14, 2023
1 parent 15125ff commit 21c80b0
Showing 1 changed file with 28 additions and 36 deletions.
64 changes: 28 additions & 36 deletions python/tests/test_haplotype_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ def update_probabilities(self, site, haplotype_state):
match = (
haplotype_state == MISSING or haplotype_state == allelic_state[v]
)
# Note that the node u is used only by Viterbi
st.value = self.compute_next_probability(site.id, st.value, match, u)

# Unset the states
Expand Down Expand Up @@ -383,12 +384,15 @@ def compute_emission_proba(self, site_id, is_match):
p_e = mu / (n_alleles - 1)
return p_e

def run(self, h):
n = self.ts.num_samples
def initialise(self, value):
self.tree.clear()
for u in self.ts.samples():
self.T_index[u] = len(self.T)
self.T.append(ValueTransition(tree_node=u, value=1 / n))
self.T.append(ValueTransition(tree_node=u, value=value))

def run(self, h):
n = self.ts.num_samples
self.initialise(1 / n)
while self.tree.next():
self.update_tree()
for site in self.tree.sites():
Expand Down Expand Up @@ -533,7 +537,9 @@ def traceback(self):


class ForwardAlgorithm(LsHmmAlgorithm):
"""Runs the Li and Stephens forward algorithm."""
"""
The Li and Stephens forward algorithm.
"""

def __init__(
self, ts, rho, mu, alleles, n_alleles, scale_mutation=False, precision=10
Expand All @@ -557,7 +563,6 @@ def compute_normalisation_factor(self):
s += self.N[j] * st.value
return s

# Note node only used in Viterbi
def compute_next_probability(self, site_id, p_last, is_match, node):
rho = self.rho[site_id]
n = self.ts.num_samples
Expand All @@ -566,8 +571,10 @@ def compute_next_probability(self, site_id, p_last, is_match, node):
return p_t * p_e


class MirroredBackwardAlgorithm(LsHmmAlgorithm):
"""Runs the Li and Stephens backward algorithm."""
class BackwardAlgorithm(ForwardAlgorithm):
"""
The Li and Stephens backward algorithm.
"""

def __init__(
self,
Expand All @@ -591,50 +598,38 @@ def __init__(
)
self.output = CompressedMatrix(ts, normalisation_factor)

def compute_normalisation_factor(self):
s = 0
for j, st in enumerate(self.T):
assert st.tree_node != tskit.NULL
assert self.N[j] > 0
s += self.N[j] * st.value
return s

def compute_next_probability(self, site_id, p_next, is_match, node):
p_e = self.compute_emission_proba(site_id, is_match)
return p_next * p_e

def process_site(self, site, haplotype_state):
s = self.output.normalisation_factor[site.id]
self.output.store_site(
site.id,
self.output.normalisation_factor[site.id],
s,
[(st.tree_node, st.value) for st in self.T],
)
self.update_probabilities(site, haplotype_state)
self.compress()
b_last_sum = self.compute_normalisation_factor()
s = self.output.normalisation_factor[site.id]
n = self.ts.num_samples
rho = self.rho[site.id]
for st in self.T:
if st.tree_node != tskit.NULL:
st.value = (self.rho[site.id] / self.ts.num_samples) * b_last_sum + (
1 - self.rho[site.id]
) * st.value
st.value = rho * b_last_sum / n + (1 - rho) * st.value
st.value /= s
st.value = round(st.value, self.precision)

def run(self, h):
# NOTE value=1 is the difference to the standard code for fwd and vit
self.tree.clear()
for u in self.ts.samples():
self.T_index[u] = len(self.T)
self.T.append(ValueTransition(tree_node=u, value=1))
while self.tree.next():
self.update_tree()
for site in self.tree.sites():
self.initialise(value=1)
while self.tree.prev():
self.update_tree(direction=tskit.REVERSE)
for site in reversed(list(self.tree.sites())):
self.process_site(site, h[site.id])
return self.output


class BackwardAlgorithm(LsHmmAlgorithm):
class MirroredBackwardAlgorithm(LsHmmAlgorithm):
"""Runs the Li and Stephens backward algorithm."""

def __init__(
Expand Down Expand Up @@ -690,13 +685,10 @@ def process_site(self, site, haplotype_state):
st.value = round(st.value, self.precision)

def run(self, h):
self.tree.clear()
for u in self.ts.samples():
self.T_index[u] = len(self.T)
self.T.append(ValueTransition(tree_node=u, value=1))
while self.tree.prev():
self.update_tree(direction=tskit.REVERSE)
for site in reversed(list(self.tree.sites())):
self.initialise(value=1)
while self.tree.next():
self.update_tree()
for site in self.tree.sites():
self.process_site(site, h[site.id])
return self.output

Expand Down

0 comments on commit 21c80b0

Please sign in to comment.