Skip to content

Commit

Permalink
Fixed bug in next proba expression
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Jul 15, 2023
1 parent 430cba3 commit 82e97eb
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 8 deletions.
21 changes: 20 additions & 1 deletion c/tskit/haplotype_matching.c
Original file line number Diff line number Diff line change
Expand Up @@ -1140,14 +1140,33 @@ tsk_ls_hmm_process_site_backward(tsk_ls_hmm_t *self, const tsk_site_t *site,
b_last_sum = self->compute_normalisation_factor(self);
for (j = 0; j < self->num_transitions; j++) {
tsk_bug_assert(T[j].tree_node != TSK_NULL);
x = T[j].value * b_last_sum / n + (1 - rho) * T[j].value;
x = rho * b_last_sum / n + (1 - rho) * T[j].value;
x /= normalisation_factor;
T[j].value = tsk_round(x, precision);
}
out:
return ret;
}

/* def process_site(self, site, haplotype_state, s): */
/* self.output.store_site( */
/* site.id, */
/* s, */
/* # We need to filter out the -1 nodes here for the first site */
/* # we examine. This is a bit of a hack */
/* [(st.tree_node, st.value) for st in self.T if st.tree_node != -1], */
/* ) */
/* self.update_probabilities(site, haplotype_state) */
/* self.compress() */
/* b_last_sum = self.compute_normalisation_factor() */
/* n = self.ts.num_samples */
/* rho = self.rho[site.id] */
/* for st in self.T: */
/* if st.tree_node != tskit.NULL: */
/* st.value = rho * b_last_sum / n + (1 - rho) * st.value */
/* st.value /= s */
/* st.value = round(st.value, self.precision) */

static int
tsk_ls_hmm_run_backward(
tsk_ls_hmm_t *self, int32_t *haplotype, const double *forward_norm)
Expand Down
31 changes: 24 additions & 7 deletions python/tests/test_haplotype_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,7 +1118,7 @@ def check_forward_matrix(ts, h, recombination=None, mutation=None):


def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None):
# precision = 22
precision = 22
h = np.array(h).astype(np.int8)
n = ts.num_samples
assert len(h) == ts.num_sites
Expand All @@ -1143,21 +1143,38 @@ def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None):
recombination,
mutation,
forward_cm.normalisation_factor,
precision=precision,
)
nt.assert_array_equal(
backward_cm.normalisation_factor, forward_cm.normalisation_factor
)
B_tree = backward_cm.decode()
nt.assert_array_almost_equal(B, B_tree)

# ll_ts = ts._ll_tree_sequence
# ls_hmm = _tskit.LsHmm(ll_ts, recombination, mutation, precision=precision)
# cm_lib = _tskit.CompressedMatrix(ll_ts)
# ls_hmm.backward_matrix(h, forward_cm.normalisation_factor, cm_lib)
# B_lib = cm_lib.decode()
ll_ts = ts._ll_tree_sequence
ls_hmm = _tskit.LsHmm(ll_ts, recombination, mutation, precision=precision)
cm_lib = _tskit.CompressedMatrix(ll_ts)
ls_hmm.backward_matrix(h, forward_cm.normalisation_factor, cm_lib)

for j in range(ts.num_sites):
py_site = backward_cm.get_site(j)
lib_site = backward_cm.get_site(j)
assert len(py_site) == len(lib_site)
py_site = dict(py_site)
lib_site = dict(lib_site)
assert set(py_site.keys()) == set(lib_site.keys())
# NOTE this probably won't work always and we'll need to put in
# some wiggle. But, they should be identical values, up to precision.
# However, the C and Python round() implementations are slightly different
# so this will almost certainly break.
for node in py_site.keys():
assert py_site[node] == lib_site[node]

nt.assert_array_equal(cm_lib.normalisation_factor, forward_cm.normalisation_factor)
B_lib = cm_lib.decode()
# print(B_lib)
# print(B)
# nt.assert_array_almost_equal(B, B_lib)
nt.assert_array_almost_equal(B, B_lib)


def add_unique_sample_mutations(ts, start=0):
Expand Down

0 comments on commit 82e97eb

Please sign in to comment.