From 4d93a823f3b7f0eff8ec510e0ffc72a5991bf029 Mon Sep 17 00:00:00 2001 From: szhan Date: Mon, 1 Jul 2024 17:17:25 +0100 Subject: [PATCH] Add argument for pass function to define emission probabilities --- lshmm/api.py | 106 ++++++++++++------- lshmm/fb_haploid.py | 18 ++-- lshmm/vit_haploid.py | 76 +++++++++----- tests/test_api_fb_haploid.py | 5 +- tests/test_api_fb_haploid_multi.py | 3 + tests/test_api_vit_haploid_multi.py | 19 ++-- tests/test_nontree_fb_haploid.py | 46 ++++++++- tests/test_nontree_vit_haploid.py | 155 +++++++++++++++++++++++----- 8 files changed, 318 insertions(+), 110 deletions(-) diff --git a/lshmm/api.py b/lshmm/api.py index d6d662d..890a985 100644 --- a/lshmm/api.py +++ b/lshmm/api.py @@ -223,23 +223,34 @@ def forwards( ) if ploidy == 1: - forward_function = forwards_ls_hap + ( + forward_array, + normalisation_factor_from_forward, + log_lik, + ) = forwards_ls_hap( + num_ref_haps, + num_sites, + ref_panel_checked, + query_checked, + emission_matrix, + prob_recombination, + norm=normalise, + emission_func=core.get_emission_probability_haploid, + ) else: - forward_function = forward_ls_dip_loop - - ( - forward_array, - normalisation_factor_from_forward, - log_lik, - ) = forward_function( - num_ref_haps, - num_sites, - ref_panel_checked, - query_checked, - emission_matrix, - prob_recombination, - norm=normalise, - ) + ( + forward_array, + normalisation_factor_from_forward, + log_lik, + ) = forward_ls_dip_loop( + num_ref_haps, + num_sites, + ref_panel_checked, + query_checked, + emission_matrix, + prob_recombination, + norm=normalise, + ) return forward_array, normalisation_factor_from_forward, log_lik @@ -267,19 +278,26 @@ def backwards( ) if ploidy == 1: - backward_function = backwards_ls_hap + backwards_array = backwards_ls_hap( + num_ref_haps, + num_sites, + ref_panel_checked, + query_checked, + emission_matrix, + normalisation_factor_from_forward, + prob_recombination, + emission_func=core.get_emission_probability_haploid, + ) else: - backward_function = backward_ls_dip_loop - - backwards_array = backward_function( - num_ref_haps, - num_sites, - ref_panel_checked, - query_checked, - emission_matrix, - normalisation_factor_from_forward, - prob_recombination, - ) + backwards_array = backward_ls_dip_loop( + num_ref_haps, + num_sites, + ref_panel_checked, + query_checked, + emission_matrix, + normalisation_factor_from_forward, + prob_recombination, + ) return backwards_array @@ -313,6 +331,7 @@ def viterbi( query_checked, emission_matrix, prob_recombination, + emission_func=core.get_emission_probability_haploid, ) best_path = backwards_viterbi_hap(num_sites, V, P) else: @@ -353,18 +372,25 @@ def path_loglik( ) if ploidy == 1: - path_ll_function = path_ll_hap + log_lik = path_ll_hap( + num_ref_haps, + num_sites, + ref_panel_checked, + path, + query_checked, + emission_matrix, + prob_recombination, + emission_func=core.get_emission_probability_haploid, + ) else: - path_ll_function = path_ll_dip - - log_lik = path_ll_function( - num_ref_haps, - num_sites, - ref_panel_checked, - path, - query_checked, - emission_matrix, - prob_recombination, - ) + log_lik = path_ll_dip( + num_ref_haps, + num_sites, + ref_panel_checked, + path, + query_checked, + emission_matrix, + prob_recombination, + ) return log_lik diff --git a/lshmm/fb_haploid.py b/lshmm/fb_haploid.py index 7ae8460..8f7b311 100644 --- a/lshmm/fb_haploid.py +++ b/lshmm/fb_haploid.py @@ -7,7 +7,9 @@ @jit.numba_njit -def forwards_ls_hap(n, m, H, s, e, r, norm=True): +def forwards_ls_hap( + n, m, H, s, e, r, emission_func, *, norm=True, +): """ A matrix-based implementation using Numpy. @@ -20,7 +22,7 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True): if norm: c = np.zeros(m) for i in range(n): - emission_prob = core.get_emission_probability_haploid( + emission_prob = emission_func( ref_allele=H[0, i], query_allele=s[0, 0], site=0, @@ -36,7 +38,7 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True): for l in range(1, m): for i in range(n): F[l, i] = F[l - 1, i] * (1 - r[l]) + r_n[l] - emission_prob = core.get_emission_probability_haploid( + emission_prob = emission_func( ref_allele=H[l, i], query_allele=s[0, l], site=l, @@ -53,7 +55,7 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True): else: c = np.ones(m) for i in range(n): - emission_prob = core.get_emission_probability_haploid( + emission_prob = emission_func( ref_allele=H[0, i], query_allele=s[0, 0], site=0, @@ -65,7 +67,7 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True): for l in range(1, m): for i in range(n): F[l, i] = F[l - 1, i] * (1 - r[l]) + np.sum(F[l - 1, :]) * r_n[l] - emission_prob = core.get_emission_probability_haploid( + emission_prob = emission_func( ref_allele=H[l, i], query_allele=s[0, l], site=l, @@ -79,7 +81,9 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True): @jit.numba_njit -def backwards_ls_hap(n, m, H, s, e, c, r): +def backwards_ls_hap( + n, m, H, s, e, c, r, *, emission_func, +): """ A matrix-based implementation using Numpy. @@ -96,7 +100,7 @@ def backwards_ls_hap(n, m, H, s, e, c, r): tmp_B = np.zeros(n) tmp_B_sum = 0 for i in range(n): - emission_prob = core.get_emission_probability_haploid( + emission_prob = emission_func( ref_allele=H[l + 1, i], query_allele=s[0, l + 1], site=l + 1, diff --git a/lshmm/vit_haploid.py b/lshmm/vit_haploid.py index 87dbb97..b772324 100644 --- a/lshmm/vit_haploid.py +++ b/lshmm/vit_haploid.py @@ -7,7 +7,9 @@ @jit.numba_njit -def viterbi_naive_init(n, m, H, s, e, r): +def viterbi_naive_init( + n, m, H, s, e, r, *, emission_func, +): """Initialise a naive implementation.""" V = np.zeros((m, n)) P = np.zeros((m, n), dtype=np.int64) @@ -15,7 +17,7 @@ def viterbi_naive_init(n, m, H, s, e, r): r_n = r / num_copiable_entries for i in range(n): - emission_prob = core.get_emission_probability_haploid( + emission_prob = emission_func( ref_allele=H[0, i], query_allele=s[0, 0], site=0, @@ -27,7 +29,9 @@ def viterbi_naive_init(n, m, H, s, e, r): @jit.numba_njit -def viterbi_init(n, m, H, s, e, r): +def viterbi_init( + n, m, H, s, e, r, *, emission_func, +): """Initialise a naive, but more memory efficient, implementation.""" V_prev = np.zeros(n) V = np.zeros(n) @@ -36,7 +40,7 @@ def viterbi_init(n, m, H, s, e, r): r_n = r / num_copiable_entries for i in range(n): - emission_prob = core.get_emission_probability_haploid( + emission_prob = emission_func( ref_allele=H[0, i], query_allele=s[0, 0], site=0, @@ -48,15 +52,17 @@ def viterbi_init(n, m, H, s, e, r): @jit.numba_njit -def forwards_viterbi_hap_naive(n, m, H, s, e, r): +def forwards_viterbi_hap_naive( + n, m, H, s, e, r, *, emission_func, +): """A naive implementation of the forward pass.""" - V, P, r_n = viterbi_naive_init(n, m, H, s, e, r) + V, P, r_n = viterbi_naive_init(n, m, H, s, e, r, emission_func) for j in range(1, m): for i in range(n): v = np.zeros(n) for k in range(n): - emission_prob = core.get_emission_probability_haploid( + emission_prob = emission_func( ref_allele=H[j, i], query_allele=s[0, j], site=j, @@ -76,16 +82,18 @@ def forwards_viterbi_hap_naive(n, m, H, s, e, r): @jit.numba_njit -def forwards_viterbi_hap_naive_vec(n, m, H, s, e, r): +def forwards_viterbi_hap_naive_vec( + n, m, H, s, e, r, *, emission_func, +): """A naive matrix-based implementation of the forward pass.""" - V, P, r_n = viterbi_naive_init(n, m, H, s, e, r) + V, P, r_n = viterbi_naive_init(n, m, H, s, e, r, emission_func) for j in range(1, m): v_tmp = V[j - 1, :] * r_n[j] for i in range(n): v = np.copy(v_tmp) v[i] += V[j - 1, i] * (1 - r[j]) - emission_prob = core.get_emission_probability_haploid( + emission_prob = emission_func( ref_allele=H[j, i], query_allele=s[0, j], site=j, @@ -101,15 +109,17 @@ def forwards_viterbi_hap_naive_vec(n, m, H, s, e, r): @jit.numba_njit -def forwards_viterbi_hap_naive_low_mem(n, m, H, s, e, r): +def forwards_viterbi_hap_naive_low_mem( + n, m, H, s, e, r, *, emission_func, +): """A naive implementation of the forward pass with reduced memory.""" - V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r) + V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r, emission_func) for j in range(1, m): for i in range(n): v = np.zeros(n) for k in range(n): - emission_prob = core.get_emission_probability_haploid( + emission_prob = emission_func( ref_allele=H[j, i], query_allele=s[0, j], site=j, @@ -130,9 +140,11 @@ def forwards_viterbi_hap_naive_low_mem(n, m, H, s, e, r): @jit.numba_njit -def forwards_viterbi_hap_naive_low_mem_rescaling(n, m, H, s, e, r): +def forwards_viterbi_hap_naive_low_mem_rescaling( + n, m, H, s, e, r, *, emission_func, +): """A naive implementation of the forward pass with reduced memory and rescaling.""" - V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r) + V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r, emission_func) c = np.ones(m) for j in range(1, m): @@ -141,7 +153,7 @@ def forwards_viterbi_hap_naive_low_mem_rescaling(n, m, H, s, e, r): for i in range(n): v = np.zeros(n) for k in range(n): - emission_prob = core.get_emission_probability_haploid( + emission_prob = emission_func( ref_allele=H[j, i], query_allele=s[0, j], site=j, @@ -162,9 +174,11 @@ def forwards_viterbi_hap_naive_low_mem_rescaling(n, m, H, s, e, r): @jit.numba_njit -def forwards_viterbi_hap_low_mem_rescaling(n, m, H, s, e, r): +def forwards_viterbi_hap_low_mem_rescaling( + n, m, H, s, e, r, *, emission_func, +): """An implementation with reduced memory that exploits the Markov structure.""" - V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r) + V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r, emission_func) c = np.ones(m) for j in range(1, m): @@ -178,7 +192,7 @@ def forwards_viterbi_hap_low_mem_rescaling(n, m, H, s, e, r): if V[i] < r_n[j]: V[i] = r_n[j] P[j, i] = argmax - emission_prob = core.get_emission_probability_haploid( + emission_prob = emission_func( ref_allele=H[j, i], query_allele=s[0, j], site=j, @@ -193,7 +207,9 @@ def forwards_viterbi_hap_low_mem_rescaling(n, m, H, s, e, r): @jit.numba_njit -def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r): +def forwards_viterbi_hap_lower_mem_rescaling( + n, m, H, s, e, r, *, emission_func, +): """ An implementation with even smaller memory footprint that exploits the Markov structure. @@ -202,7 +218,7 @@ def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r): """ V = np.zeros(n) for i in range(n): - emission_prob = core.get_emission_probability_haploid( + emission_prob = emission_func( ref_allele=H[0, i], query_allele=s[0, 0], site=0, @@ -224,7 +240,7 @@ def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r): if V[i] < r_n[j]: V[i] = r_n[j] P[j, i] = argmax - emission_prob = core.get_emission_probability_haploid( + emission_prob = emission_func( ref_allele=H[j, i], query_allele=s[0, j], site=j, @@ -238,14 +254,16 @@ def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r): @jit.numba_njit -def forwards_viterbi_hap_lower_mem_rescaling_no_pointer(n, m, H, s, e, r): +def forwards_viterbi_hap_lower_mem_rescaling_no_pointer( + n, m, H, s, e, r, *, emission_func, +): """ An implementation with even smaller memory footprint and rescaling that exploits the Markov structure. """ V = np.zeros(n) for i in range(n): - emission_prob = core.get_emission_probability_haploid( + emission_prob = emission_func( ref_allele=H[0, i], query_allele=s[0, 0], site=0, @@ -273,7 +291,7 @@ def forwards_viterbi_hap_lower_mem_rescaling_no_pointer(n, m, H, s, e, r): recombs[j] = np.append( recombs[j], i ) # We add template i as a potential template to recombine to at site j. - emission_prob = core.get_emission_probability_haploid( + emission_prob = emission_func( ref_allele=H[j, i], query_allele=s[0, j], site=j, @@ -320,13 +338,15 @@ def backwards_viterbi_hap_no_pointer(m, V_argmaxes, recombs): @jit.numba_njit -def path_ll_hap(n, m, H, path, s, e, r): +def path_ll_hap( + n, m, H, path, s, e, r, *, emission_func, +): """ Evaluate the log-likelihood of a path through a reference panel resulting in a query. This is exposed via the API. """ - emission_prob = core.get_emission_probability_haploid( + emission_prob = emission_func( ref_allele=H[0, path[0]], query_allele=s[0, 0], site=0, @@ -338,7 +358,7 @@ def path_ll_hap(n, m, H, path, s, e, r): r_n = r / num_copiable_entries for l in range(1, m): - emission_prob = core.get_emission_probability_haploid( + emission_prob = emission_func( ref_allele=H[l, path[l]], query_allele=s[0, l], site=l, diff --git a/tests/test_api_fb_haploid.py b/tests/test_api_fb_haploid.py index 7238283..1cf9656 100644 --- a/tests/test_api_fb_haploid.py +++ b/tests/test_api_fb_haploid.py @@ -16,7 +16,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): include_ancestors=include_ancestors, include_extreme_rates=True, ): - num_alleles = core.get_num_alleles(H_vs, s) + emission_func = core.get_emission_probability_haploid F_vs, c_vs, ll_vs = fbh.forwards_ls_hap( n=n, m=m, @@ -24,6 +24,8 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): s=s, e=e_vs, r=r, + emission_func=emission_func, + norm=True, ) B_vs = fbh.backwards_ls_hap( n=n, @@ -33,6 +35,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): e=e_vs, c=c_vs, r=r, + emission_func=emission_func, ) F, c, ll = ls.forwards( reference_panel=H_vs, diff --git a/tests/test_api_fb_haploid_multi.py b/tests/test_api_fb_haploid_multi.py index a90f57a..77b1b6d 100644 --- a/tests/test_api_fb_haploid_multi.py +++ b/tests/test_api_fb_haploid_multi.py @@ -16,6 +16,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): include_ancestors=include_ancestors, include_extreme_rates=True, ): + emission_func = core.get_emission_probability_haploid F_vs, c_vs, ll_vs = fbh.forwards_ls_hap( n=n, m=m, @@ -23,6 +24,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): s=s, e=e_vs, r=r, + emission_func=emission_func, ) B_vs = fbh.backwards_ls_hap( n=n, @@ -32,6 +34,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): e=e_vs, c=c_vs, r=r, + emission_func=emission_func, ) F, c, ll = ls.forwards( reference_panel=H_vs, diff --git a/tests/test_api_vit_haploid_multi.py b/tests/test_api_vit_haploid_multi.py index 5020171..d7dce8a 100644 --- a/tests/test_api_vit_haploid_multi.py +++ b/tests/test_api_vit_haploid_multi.py @@ -16,6 +16,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): include_ancestors=include_ancestors, include_extreme_rates=True, ): + emission_func = core.get_emission_probability_haploid V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_lower_mem_rescaling( n=n, m=m, @@ -23,9 +24,19 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): s=s, e=e_vs, r=r, + emission_func=emission_func, ) path_vs = vh.backwards_viterbi_hap(m=m, V_last=V_vs, P=P_vs) - path_ll_hap = vh.path_ll_hap(n, m, H_vs, path_vs, s, e_vs, r) + path_ll_hap = vh.path_ll_hap( + n=n, + m=m, + H=H_vs, + path=path_vs, + s=s, + e=e_vs, + r=r, + emission_func=emission_func, + ) path, ll = ls.viterbi( reference_panel=H_vs, query=s, @@ -44,11 +55,7 @@ def test_ts_multiallelic_n10_no_recomb( self, scale_mutation_rate, include_ancestors ): ts = self.get_ts_multiallelic_n10_no_recomb() - self.verify( - ts, - scale_mutation_rate=scale_mutation_rate, - include_ancestors=include_ancestors, - ) + self.verify(ts, scale_mutation_rate, include_ancestors) @pytest.mark.parametrize("num_samples", [6, 8, 16]) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) diff --git a/tests/test_nontree_fb_haploid.py b/tests/test_nontree_fb_haploid.py index 7f76a0f..331aabe 100644 --- a/tests/test_nontree_fb_haploid.py +++ b/tests/test_nontree_fb_haploid.py @@ -10,20 +10,56 @@ class TestNonTreeForwardBackwardHaploid(lsbase.ForwardBackwardAlgorithmBase): def verify(self, ts, scale_mutation_rate, include_ancestors): + ploidy = 1 for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars( ts, - ploidy=1, + ploidy=ploidy, scale_mutation_rate=scale_mutation_rate, include_ancestors=include_ancestors, include_extreme_rates=True, ): - F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r, norm=False) - B_vs = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_vs, r) + emission_func = core.get_emission_probability_haploid + F_vs, c_vs, ll_vs = fbh.forwards_ls_hap( + n=n, + m=m, + H=H_vs, + s=s, + e=e_vs, + r=r, + emission_func=emission_func, + norm=False, + ) + B_vs = fbh.backwards_ls_hap( + n=n, + m=m, + H=H_vs, + s=s, + e=e_vs, + c=c_vs, + r=r, + emission_func=emission_func, + ) self.assertAllClose(np.log10(np.sum(F_vs * B_vs, 1)), ll_vs * np.ones(m)) F_tmp, c_tmp, ll_tmp = fbh.forwards_ls_hap( - n, m, H_vs, s, e_vs, r, norm=True + n=n, + m=m, + H=H_vs, + s=s, + e=e_vs, + r=r, + emission_func=emission_func, + norm=True, + ) + B_tmp = fbh.backwards_ls_hap( + n=n, + m=m, + H=H_vs, + s=s, + e=e_vs, + c=c_tmp, + r=r, + emission_func=emission_func, ) - B_tmp = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_tmp, r) self.assertAllClose(np.sum(F_tmp * B_tmp, 1), np.ones(m)) self.assertAllClose(ll_vs, ll_tmp) diff --git a/tests/test_nontree_vit_haploid.py b/tests/test_nontree_vit_haploid.py index 93bd7a5..475d121 100644 --- a/tests/test_nontree_vit_haploid.py +++ b/tests/test_nontree_vit_haploid.py @@ -10,54 +10,148 @@ class TestNonTreeViterbiHaploid(lsbase.ViterbiAlgorithmBase): def verify(self, ts, scale_mutation_rate, include_ancestors): + ploidy = 1 for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars( ts, - ploidy=1, + ploidy=ploidy, scale_mutation_rate=scale_mutation_rate, include_ancestors=include_ancestors, include_extreme_rates=True, ): - V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_naive(n, m, H_vs, s, e_vs, r) - path_vs = vh.backwards_viterbi_hap(m, V_vs[m - 1, :], P_vs) - ll_check = vh.path_ll_hap(n, m, H_vs, path_vs, s, e_vs, r) + emission_func = core.get_emission_probability_haploid + + V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_naive( + n=n, + m=m, + H=H_vs, + s=s, + e=e_vs, + r=r, + emission_func=emission_func, + ) + path_vs = vh.backwards_viterbi_hap( + m=m, + V_last=V_vs[m - 1, :], + P=P_vs, + ) + ll_check = vh.path_ll_hap( + n=n, + m=m, + H=H_vs, + path=path_vs, + s=s, + e=e_vs, + r=r, + emission_func=emission_func, + ) self.assertAllClose(ll_vs, ll_check) V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_vec( - n, m, H_vs, s, e_vs, r + n=n, + m=m, + H=H_vs, + s=s, + e=e_vs, + r=r, + emission_func=emission_func, + ) + path_tmp = vh.backwards_viterbi_hap( + m=m, + V_last=V_tmp[m - 1, :], + P=P_tmp, + ) + ll_check = vh.path_ll_hap( + n=n, + m=m, + H=H_vs, + path=path_tmp, + s=s, + e=e_vs, + r=r, + emission_func=emission_func, ) - path_tmp = vh.backwards_viterbi_hap(m, V_tmp[m - 1, :], P_tmp) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) self.assertAllClose(ll_tmp, ll_check) self.assertAllClose(ll_vs, ll_tmp) V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem( - n, m, H_vs, s, e_vs, r + n=n, + m=m, + H=H_vs, + s=s, + e=e_vs, + r=r, + emission_func=emission_func, + ) + path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp) + ll_check = vh.path_ll_hap( + n=n, + m=m, + H=H_vs, + path=path_tmp, + s=s, + e=e_vs, + r=r, + emission_func=emission_func, ) - path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) self.assertAllClose(ll_tmp, ll_check) self.assertAllClose(ll_vs, ll_tmp) V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem_rescaling( - n, m, H_vs, s, e_vs, r + n=n, + m=m, + H=H_vs, + s=s, + e=e_vs, + r=r, + emission_func=emission_func, + ) + path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp) + ll_check = vh.path_ll_hap( + n=n, + m=m, + H=H_vs, + path=path_tmp, + s=s, + e=e_vs, + r=r, + emission_func=emission_func, ) - path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) self.assertAllClose(ll_tmp, ll_check) self.assertAllClose(ll_vs, ll_tmp) V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_low_mem_rescaling( - n, m, H_vs, s, e_vs, r + n=n, + m=m, + H=H_vs, + s=s, + e=e_vs, + r=r, + emission_func=emission_func, + ) + path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp) + ll_check = vh.path_ll_hap( + n=n, + m=m, + H=H_vs, + path=path_tmp, + s=s, + e=e_vs, + r=r, + emission_func=emission_func, ) - path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) self.assertAllClose(ll_tmp, ll_check) self.assertAllClose(ll_vs, ll_tmp) V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_lower_mem_rescaling( - n, m, H_vs, s, e_vs, r + n=n, + m=m, + H=H_vs, + s=s, + e=e_vs, + r=r, + emission_func=emission_func, ) - path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) + path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp) ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) self.assertAllClose(ll_tmp, ll_check) self.assertAllClose(ll_vs, ll_tmp) @@ -68,14 +162,29 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): recombs, ll_tmp, ) = vh.forwards_viterbi_hap_lower_mem_rescaling_no_pointer( - n, m, H_vs, s, e_vs, r + n=n, + m=m, + H=H_vs, + s=s, + e=e_vs, + r=r, + emission_func=emission_func, ) path_tmp = vh.backwards_viterbi_hap_no_pointer( - m, - V_argmaxes_tmp, - nb.typed.List(recombs), + m=m, + V_argmaxes=V_argmaxes_tmp, + recombs=nb.typed.List(recombs), + ) + ll_check = vh.path_ll_hap( + n=n, + m=m, + H=H_vs, + path=path_tmp, + s=s, + e=e_vs, + r=r, + emission_func=emission_func, ) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) self.assertAllClose(ll_tmp, ll_check) self.assertAllClose(ll_vs, ll_tmp)