Skip to content

Commit

Permalink
Add argument for pass function to define emission probabilities
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Jul 1, 2024
1 parent 5f49c42 commit 4d93a82
Show file tree
Hide file tree
Showing 8 changed files with 318 additions and 110 deletions.
106 changes: 66 additions & 40 deletions lshmm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,23 +223,34 @@ def forwards(
)

if ploidy == 1:
forward_function = forwards_ls_hap
(
forward_array,
normalisation_factor_from_forward,
log_lik,
) = forwards_ls_hap(
num_ref_haps,
num_sites,
ref_panel_checked,
query_checked,
emission_matrix,
prob_recombination,
norm=normalise,
emission_func=core.get_emission_probability_haploid,
)
else:
forward_function = forward_ls_dip_loop

(
forward_array,
normalisation_factor_from_forward,
log_lik,
) = forward_function(
num_ref_haps,
num_sites,
ref_panel_checked,
query_checked,
emission_matrix,
prob_recombination,
norm=normalise,
)
(
forward_array,
normalisation_factor_from_forward,
log_lik,
) = forward_ls_dip_loop(
num_ref_haps,
num_sites,
ref_panel_checked,
query_checked,
emission_matrix,
prob_recombination,
norm=normalise,
)

return forward_array, normalisation_factor_from_forward, log_lik

Expand Down Expand Up @@ -267,19 +278,26 @@ def backwards(
)

if ploidy == 1:
backward_function = backwards_ls_hap
backwards_array = backwards_ls_hap(
num_ref_haps,
num_sites,
ref_panel_checked,
query_checked,
emission_matrix,
normalisation_factor_from_forward,
prob_recombination,
emission_func=core.get_emission_probability_haploid,
)
else:
backward_function = backward_ls_dip_loop

backwards_array = backward_function(
num_ref_haps,
num_sites,
ref_panel_checked,
query_checked,
emission_matrix,
normalisation_factor_from_forward,
prob_recombination,
)
backwards_array = backward_ls_dip_loop(
num_ref_haps,
num_sites,
ref_panel_checked,
query_checked,
emission_matrix,
normalisation_factor_from_forward,
prob_recombination,
)

return backwards_array

Expand Down Expand Up @@ -313,6 +331,7 @@ def viterbi(
query_checked,
emission_matrix,
prob_recombination,
emission_func=core.get_emission_probability_haploid,
)
best_path = backwards_viterbi_hap(num_sites, V, P)
else:
Expand Down Expand Up @@ -353,18 +372,25 @@ def path_loglik(
)

if ploidy == 1:
path_ll_function = path_ll_hap
log_lik = path_ll_hap(
num_ref_haps,
num_sites,
ref_panel_checked,
path,
query_checked,
emission_matrix,
prob_recombination,
emission_func=core.get_emission_probability_haploid,
)
else:
path_ll_function = path_ll_dip

log_lik = path_ll_function(
num_ref_haps,
num_sites,
ref_panel_checked,
path,
query_checked,
emission_matrix,
prob_recombination,
)
log_lik = path_ll_dip(
num_ref_haps,
num_sites,
ref_panel_checked,
path,
query_checked,
emission_matrix,
prob_recombination,
)

return log_lik
18 changes: 11 additions & 7 deletions lshmm/fb_haploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@


@jit.numba_njit
def forwards_ls_hap(n, m, H, s, e, r, norm=True):
def forwards_ls_hap(
n, m, H, s, e, r, emission_func, *, norm=True,
):
"""
A matrix-based implementation using Numpy.
Expand All @@ -20,7 +22,7 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True):
if norm:
c = np.zeros(m)
for i in range(n):
emission_prob = core.get_emission_probability_haploid(
emission_prob = emission_func(
ref_allele=H[0, i],
query_allele=s[0, 0],
site=0,
Expand All @@ -36,7 +38,7 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True):
for l in range(1, m):
for i in range(n):
F[l, i] = F[l - 1, i] * (1 - r[l]) + r_n[l]
emission_prob = core.get_emission_probability_haploid(
emission_prob = emission_func(
ref_allele=H[l, i],
query_allele=s[0, l],
site=l,
Expand All @@ -53,7 +55,7 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True):
else:
c = np.ones(m)
for i in range(n):
emission_prob = core.get_emission_probability_haploid(
emission_prob = emission_func(
ref_allele=H[0, i],
query_allele=s[0, 0],
site=0,
Expand All @@ -65,7 +67,7 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True):
for l in range(1, m):
for i in range(n):
F[l, i] = F[l - 1, i] * (1 - r[l]) + np.sum(F[l - 1, :]) * r_n[l]
emission_prob = core.get_emission_probability_haploid(
emission_prob = emission_func(
ref_allele=H[l, i],
query_allele=s[0, l],
site=l,
Expand All @@ -79,7 +81,9 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True):


@jit.numba_njit
def backwards_ls_hap(n, m, H, s, e, c, r):
def backwards_ls_hap(
n, m, H, s, e, c, r, *, emission_func,
):
"""
A matrix-based implementation using Numpy.
Expand All @@ -96,7 +100,7 @@ def backwards_ls_hap(n, m, H, s, e, c, r):
tmp_B = np.zeros(n)
tmp_B_sum = 0
for i in range(n):
emission_prob = core.get_emission_probability_haploid(
emission_prob = emission_func(
ref_allele=H[l + 1, i],
query_allele=s[0, l + 1],
site=l + 1,
Expand Down
Loading

0 comments on commit 4d93a82

Please sign in to comment.