Skip to content

Commit

Permalink
Backward alg working
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Jul 14, 2023
1 parent fcd9c10 commit 15125ff
Showing 1 changed file with 54 additions and 24 deletions.
78 changes: 54 additions & 24 deletions python/tests/test_haplotype_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,17 +233,22 @@ def compute(u, parent_state):
if T_parent[j] != -1:
self.N[T_parent[j]] -= self.N[j]

def update_tree(self):
def update_tree(self, direction=tskit.FORWARD):
"""
Update the internal data structures to move on to the next tree.
"""
parent = self.parent
T_index = self.T_index
T = self.T
self.tree_pos.next()
if direction == tskit.FORWARD:
self.tree_pos.next()
else:
self.tree_pos.prev()
assert self.tree_pos.index == self.tree.index

for j in range(self.tree_pos.out_range.start, self.tree_pos.out_range.stop):
for j in range(
self.tree_pos.out_range.start, self.tree_pos.out_range.stop, direction
):
e = self.tree_pos.out_range.order[j]
edge = self.ts.edge(e)
u = edge.child
Expand All @@ -258,7 +263,9 @@ def update_tree(self):
)
parent[edge.child] = -1

for j in range(self.tree_pos.in_range.start, self.tree_pos.in_range.stop):
for j in range(
self.tree_pos.in_range.start, self.tree_pos.in_range.stop, direction
):
e = self.tree_pos.in_range.order[j]
edge = self.ts.edge(e)
parent[edge.child] = edge.parent
Expand Down Expand Up @@ -336,8 +343,7 @@ def update_probabilities(self, site, haplotype_state):
for mutation in site.mutations:
allelic_state[mutation.node] = -1

def process_site(self, site, haplotype_state, forwards=True):
# Forwards algorithm, or forwards pass in Viterbi
def process_site(self, site, haplotype_state):
self.update_probabilities(site, haplotype_state)
self.compress()
s = self.compute_normalisation_factor()
Expand Down Expand Up @@ -593,7 +599,6 @@ def compute_normalisation_factor(self):
s += self.N[j] * st.value
return s

# Note node only used in Viterbi
def compute_next_probability(self, site_id, p_next, is_match, node):
p_e = self.compute_emission_proba(site_id, is_match)
return p_next * p_e
Expand All @@ -617,6 +622,7 @@ def process_site(self, site, haplotype_state):
st.value = round(st.value, self.precision)

def run(self, h):
# NOTE value=1 is the difference to the standard code for fwd and vit
self.tree.clear()
for u in self.ts.samples():
self.T_index[u] = len(self.T)
Expand Down Expand Up @@ -661,20 +667,37 @@ def compute_normalisation_factor(self):
s += self.N[j] * st.value
return s

# Note node only used in Viterbi
def compute_next_probability(self, site_id, p_next, is_match, node):
p_e = self.compute_emission_proba(site_id, is_match)
return p_next * p_e

def process_site(self, site, haplotype_state):
self.output.store_site(
site.id,
self.output.normalisation_factor[site.id],
[(st.tree_node, st.value) for st in self.T],
)
self.update_probabilities(site, haplotype_state)
self.compress()
b_last_sum = self.compute_normalisation_factor()
s = self.output.normalisation_factor[site.id]
for st in self.T:
if st.tree_node != tskit.NULL:
st.value = (self.rho[site.id] / self.ts.num_samples) * b_last_sum + (
1 - self.rho[site.id]
) * st.value
st.value /= s
st.value = round(st.value, self.precision)

def run(self, h):
self.tree.clear()
for u in self.ts.samples():
self.T_index[u] = len(self.T)
self.T.append(ValueTransition(tree_node=u, value=1))
while self.tree.next():
self.update_tree()
for site in self.tree.sites():
self.process_site(site, h[site.id], forwards=False)
while self.tree.prev():
self.update_tree(direction=tskit.REVERSE)
for site in reversed(list(self.tree.sites())):
self.process_site(site, h[site.id])
return self.output


Expand Down Expand Up @@ -1071,9 +1094,19 @@ def verify(self, ts):
np.flip(c_f.normalisation_factor),
)
B_tree = np.flip(c_b.decode(), axis=0)
c_b = ls_backward_tree(
s[0, :],
ts_check,
r,
mu,
c_f.normalisation_factor,
)
B_tree2 = c_b.decode()

F_tree = c_f.decode()

self.assertAllClose(B, B_tree)
self.assertAllClose(B, B_tree2)
self.assertAllClose(F, F_tree)
self.assertAllClose(ll, ll_tree)

Expand Down Expand Up @@ -1235,18 +1268,15 @@ def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None):
B_tree = np.flip(backward_cm.decode(), axis=0)
nt.assert_array_almost_equal(B, B_tree)

# backward_cm2 = ls_backward_tree(
# h,
# ts,
# recombination,
# mutation,
# forward_cm.normalisation_factor,
# )
# print()
# print(backward_cm2.decode())
# print()
# print(B_tree)
# # B_tree = np.flip(backward_cm.decode(), axis=0)
backward_cm2 = ls_backward_tree(
h,
ts,
recombination,
mutation,
forward_cm.normalisation_factor,
)
B_tree = backward_cm2.decode()
nt.assert_array_almost_equal(B, B_tree)


def add_unique_sample_mutations(ts, start=0):
Expand Down

0 comments on commit 15125ff

Please sign in to comment.