Merge pull request #27 from astheeggeggs/prob_not_rate
updated rate to probability where appropriate
astheeggeggs authored Jul 20, 2023
2 parents 792c74b + a9193eb commit a743130
Showing 4 changed files with 81 additions and 87 deletions.
138 changes: 67 additions & 71 deletions lshmm/api.py
@@ -51,8 +51,8 @@ def check_alleles(alleles, m):
def checks(
reference_panel,
query,
mutation_rate,
recombination_rate,
p_mutation,
p_recombination,
scale_mutation_based_on_n_alleles,
):
"""
@@ -67,9 +67,9 @@ def checks(
:param numpy.ndarray(dtype=int) reference_panel: Matrix of size (m, n) or (m, n, n).
:param numpy.ndarray(dtype=int) query: Matrix of size (k, m) or (k, m, 2).
:param numpy.ndarray(dtype=float) mutation_rate: Scalar or vector of length m.
:param numpy.ndarray(dtype=float) recombination_rate: Scalar or vector of length m.
:param bool scale_mutation_based_on_n_alleles: Whether to scale mutation rate based on the number of alleles (True) or not (False).
:param numpy.ndarray(dtype=float) p_mutation: Scalar or vector of length m.
:param numpy.ndarray(dtype=float) p_recombination: Scalar or vector of length m.
:param bool scale_mutation_based_on_n_alleles: Whether to scale the mutation probability to the set of alleles that can be mutated to based on the number of alleles (True) or not (False).
:return: n, m, ploidy
:rtype: tuple
"""
@@ -95,37 +95,32 @@ def checks(
"Number of sites in query does not match reference panel. If haploid, ensure a sites x samples matrix is passed."
)

# Ensure that the mutation rate is either a scalar or vector of length m
if isinstance(mutation_rate, (int, float)):
# Ensure that the mutation probability is either a scalar or vector of length m
if isinstance(p_mutation, (int, float)):
if not scale_mutation_based_on_n_alleles:
warnings.warn(
"Passed a scalar mutation rate, but not rescaling this mutation rate conditional on the number of alleles at the site."
"Passed a scalar probability of mutation, but not rescaling this probability of mutation conditional on the number of alleles at the site."
)
elif isinstance(mutation_rate, np.ndarray) and mutation_rate.shape[0] == m:
elif isinstance(p_mutation, np.ndarray) and p_mutation.shape[0] == m:
if scale_mutation_based_on_n_alleles:
warnings.warn(
"Passed a vector of mutation rates, but rescaling each mutation rate conditional on the number of alleles at each site."
"Passed a vector of probabilities of mutation, but rescaling each mutation probability conditional on the number of alleles at each site."
)
elif mutation_rate is None:
elif p_mutation is None:
warnings.warn(
"No mutation rate passed, setting mutation rate based on Li and Stephens 2003, equations (A2) and (A3)"
"No mutation probability passed, setting mutation probability based on Li and Stephens 2003, equations (A2) and (A3)"
)
else:
raise ValueError(
f"Mutation rate is not None, a scalar, or vector of length m: {m}"
f"Mutation probability is not None, a scalar, or vector of length m: {m}"
)

# Ensure that the recombination rate is either a scalar or a vector of length m
# Ensure that the recombination probability is either a scalar or a vector of length m
if not (
isinstance(recombination_rate, (int, float))
or (
isinstance(recombination_rate, np.ndarray)
and recombination_rate.shape[0] == m
)
isinstance(p_recombination, (int, float))
or (isinstance(p_recombination, np.ndarray) and p_recombination.shape[0] == m)
):
raise ValueError(
f"Recombination_rate is not a scalar or vector of length m: {m}"
)
raise ValueError(f"p_Recombination is not a scalar or vector of length m: {m}")

return (n, m, ploidy)
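For reference, the check above accepts each probability either as a scalar or as a per-site vector of length m; a minimal sketch of the accepted forms (the variable names and values below are illustrative, not part of the commit):

import numpy as np

m = 1000                            # number of sites (illustrative)
p_mutation = 1e-3                   # scalar: warns unless scale_mutation_based_on_n_alleles is True
p_mutation = np.full(m, 1e-3)       # length-m vector: warns if scale_mutation_based_on_n_alleles is True
p_mutation = None                   # falls back to the Li and Stephens (2003) default (with a warning)
p_recombination = np.full(m, 1e-4)  # scalar or length-m vector; any other shape raises ValueError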

@@ -136,7 +131,7 @@ def set_emission_probabilities(
reference_panel,
query,
alleles,
mutation_rate,
p_mutation,
ploidy,
scale_mutation_based_on_n_alleles,
):
@@ -152,65 +147,66 @@ def set_emission_probabilities(
else:
n_alleles = check_alleles(alleles, m)

if mutation_rate is None:
# Set the mutation rate to be the proposed mutation rate in Li and Stephens (2003).
if p_mutation is None:
# Set the mutation probability to be the proposed mutation probability in Li and Stephens (2003).
theta_tilde = 1 / np.sum([1 / k for k in range(1, n - 1)])
mutation_rate = 0.5 * (theta_tilde / (n + theta_tilde))
p_mutation = 0.5 * (theta_tilde / (n + theta_tilde))

if isinstance(mutation_rate, float):
mutation_rate = mutation_rate * np.ones(m)
if isinstance(p_mutation, float):
p_mutation = p_mutation * np.ones(m)

if ploidy == 1:
# Haploid
# Evaluate emission probabilities here, using the mutation rate - this can take a scalar or vector.
# Evaluate emission probabilities here using p_mutation - this can take a scalar or vector.
e = np.zeros((m, 2))

if scale_mutation_based_on_n_alleles:
# Scale mutation based on the number of alleles - so the mutation rate is the mutation rate to one of the alleles.
# The overall mutation rate is then (n_alleles - 1) * mutation_rate.
e[:, 0] = mutation_rate - mutation_rate * np.equal(
# Scale mutation based on the number of alleles - so p_mutation is probability of mutation any given one of the alleles.
# The overall mutation probability is then (n_alleles - 1) * p_mutation.
e[:, 0] = p_mutation - p_mutation * np.equal(
n_alleles, np.ones(m)
) # Added boolean in case we're at an invariant site
e[:, 1] = 1 - (n_alleles - 1) * mutation_rate
e[:, 1] = 1 - (n_alleles - 1) * p_mutation
else:
# No scaling based on the number of alleles - so the mutation rate is the mutation rate to anything.
# Which means that we must rescale the mutation rate to a different allele, by the number of alleles.
# No scaling based on the number of alleles - so p_mutation is the probability of mutation to anything
# (summing over the states we can switch to). This means that we must rescale the probability of mutation to
# a different allele by the number of alleles at the site.
for j in range(m):
if n_alleles[j] == 1: # In case we're at an invariant site
e[j, 0] = 0
e[j, 1] = 1
else:
e[j, 0] = mutation_rate[j] / (n_alleles[j] - 1)
e[j, 1] = 1 - mutation_rate[j]
e[j, 0] = p_mutation[j] / (n_alleles[j] - 1)
e[j, 1] = 1 - p_mutation[j]
else:
# Diploid
# Evaluate emission probabilities here, using the mutation rate - this can take a scalar or vector.
# Evaluate emission probabilities here, using the mutation probability - this can take a scalar or vector.
# DEV: there's a wrinkle here.
e = np.zeros((m, 8))
e[:, EQUAL_BOTH_HOM] = (1 - mutation_rate) ** 2
e[:, UNEQUAL_BOTH_HOM] = mutation_rate ** 2
e[:, BOTH_HET] = (1 - mutation_rate) ** 2 + mutation_rate ** 2
e[:, REF_HOM_OBS_HET] = 2 * mutation_rate * (1 - mutation_rate)
e[:, REF_HET_OBS_HOM] = mutation_rate * (1 - mutation_rate)
e[:, EQUAL_BOTH_HOM] = (1 - p_mutation) ** 2
e[:, UNEQUAL_BOTH_HOM] = p_mutation ** 2
e[:, BOTH_HET] = (1 - p_mutation) ** 2 + p_mutation ** 2
e[:, REF_HOM_OBS_HET] = 2 * p_mutation * (1 - p_mutation)
e[:, REF_HET_OBS_HOM] = p_mutation * (1 - p_mutation)
e[:, MISSING_INDEX] = 1

return e
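As a quick sanity check on the haploid branch above, a small worked sketch of the default probability and the biallelic emission terms (n = 10 is an arbitrary illustrative panel size; the expressions simply mirror the code in this hunk):

import numpy as np

n = 10  # number of reference haplotypes (illustrative)

# Default mutation probability used when p_mutation is None, as above.
theta_tilde = 1 / np.sum([1 / k for k in range(1, n - 1)])
p_mutation = 0.5 * (theta_tilde / (n + theta_tilde))

# Haploid emissions at a biallelic site (n_alleles == 2) with scaling enabled:
n_alleles = 2
e_mismatch = p_mutation - p_mutation * (n_alleles == 1)  # query allele differs from the copied haplotype
e_match = 1 - (n_alleles - 1) * p_mutation                # query allele matches the copied haplotype

# Without scaling, the mismatch term is instead p_mutation / (n_alleles - 1)
# and the match term is 1 - p_mutation, as in the else branch above.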


def viterbi_hap(n, m, reference_panel, query, emissions, recombination_rate):
def viterbi_hap(n, m, reference_panel, query, emissions, p_recombination):

V, P, log_likelihood = forwards_viterbi_hap_lower_mem_rescaling(
n, m, reference_panel, query, emissions, recombination_rate
n, m, reference_panel, query, emissions, p_recombination
)
most_likely_path = backwards_viterbi_hap(m, V, P)

return most_likely_path, log_likelihood


def viterbi_dip(n, m, reference_panel, query, emissions, recombination_rate):
def viterbi_dip(n, m, reference_panel, query, emissions, p_recombination):

V, P, log_likelihood = forwards_viterbi_dip_low_mem(
n, m, reference_panel, query, emissions, recombination_rate
n, m, reference_panel, query, emissions, p_recombination
)
unphased_path = backwards_viterbi_dip(m, V, P)
most_likely_path = get_phased_path(n, unphased_path)
@@ -221,9 +217,9 @@ def viterbi_dip(n, m, reference_panel, query, emissions, recombination_rate):
def forwards(
reference_panel,
query,
recombination_rate,
p_recombination,
alleles=None,
mutation_rate=None,
p_mutation=None,
scale_mutation_based_on_n_alleles=True,
norm=True,
):
@@ -234,8 +230,8 @@ def forwards(
n, m, ploidy = checks(
reference_panel,
query,
mutation_rate,
recombination_rate,
p_mutation,
p_recombination,
scale_mutation_based_on_n_alleles,
)

@@ -245,7 +241,7 @@ def forwards(
reference_panel,
query,
alleles,
mutation_rate,
p_mutation,
ploidy,
scale_mutation_based_on_n_alleles,
)
@@ -260,7 +256,7 @@ def forwards(
normalisation_factor_from_forward,
log_likelihood,
) = forward_function(
n, m, reference_panel, query, emissions, recombination_rate, norm=norm
n, m, reference_panel, query, emissions, p_recombination, norm=norm
)

return forward_array, normalisation_factor_from_forward, log_likelihood
@@ -270,9 +266,9 @@ def backwards(
reference_panel,
query,
normalisation_factor_from_forward,
recombination_rate,
p_recombination,
alleles=None,
mutation_rate=None,
p_mutation=None,
scale_mutation_based_on_n_alleles=True,
):
"""
@@ -282,8 +278,8 @@ def backwards(
n, m, ploidy = checks(
reference_panel,
query,
mutation_rate,
recombination_rate,
p_mutation,
p_recombination,
scale_mutation_based_on_n_alleles,
)

@@ -293,7 +289,7 @@ def backwards(
reference_panel,
query,
alleles,
mutation_rate,
p_mutation,
ploidy,
scale_mutation_based_on_n_alleles,
)
@@ -310,7 +306,7 @@ def backwards(
query,
emissions,
normalisation_factor_from_forward,
recombination_rate,
p_recombination,
)

return backwards_array
@@ -319,9 +315,9 @@ def viterbi(
def viterbi(
reference_panel,
query,
recombination_rate,
p_recombination,
alleles=None,
mutation_rate=None,
p_mutation=None,
scale_mutation_based_on_n_alleles=True,
):
"""
@@ -331,8 +327,8 @@ def viterbi(
n, m, ploidy = checks(
reference_panel,
query,
mutation_rate,
recombination_rate,
p_mutation,
p_recombination,
scale_mutation_based_on_n_alleles,
)

@@ -342,7 +338,7 @@ def viterbi(
reference_panel,
query,
alleles,
mutation_rate,
p_mutation,
ploidy,
scale_mutation_based_on_n_alleles,
)
@@ -353,7 +349,7 @@ def viterbi(
viterbi_function = viterbi_dip

most_likely_path, log_likelihood = viterbi_function(
n, m, reference_panel, query, emissions, recombination_rate
n, m, reference_panel, query, emissions, p_recombination
)

return most_likely_path, log_likelihood
@@ -363,17 +359,17 @@ def path_ll(
reference_panel,
query,
path,
recombination_rate,
p_recombination,
alleles=None,
mutation_rate=None,
p_mutation=None,
scale_mutation_based_on_n_alleles=True,
):

n, m, ploidy = checks(
reference_panel,
query,
mutation_rate,
recombination_rate,
p_mutation,
p_recombination,
scale_mutation_based_on_n_alleles,
)

@@ -383,7 +379,7 @@ def path_ll(
reference_panel,
query,
alleles,
mutation_rate,
p_mutation,
ploidy,
scale_mutation_based_on_n_alleles,
)
@@ -394,7 +390,7 @@ def path_ll(
path_ll_function = path_ll_dip

ll = path_ll_function(
n, m, reference_panel, path, query, emissions, recombination_rate
n, m, reference_panel, path, query, emissions, p_recombination
)

return ll
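Downstream callers now pass the renamed keyword arguments; a minimal usage sketch of the haploid entry points under the tests' `import lshmm as ls` alias (the arrays are illustrative, and zeroing the first recombination probability is an assumed LS-style convention rather than something stated in this commit):

import numpy as np
import lshmm as ls  # alias assumed from the tests below

m, n = 50, 8                               # sites x reference haplotypes (illustrative)
H = np.random.randint(0, 2, size=(m, n))   # haploid reference panel
s = H[:, 0].reshape(1, m)                  # query: a copy of the first panel haplotype
r = np.full(m, 1e-4)                       # per-site recombination probabilities
r[0] = 0                                   # assumed convention: no recombination into the first site
mu = np.full(m, 1e-3)                      # per-site mutation probabilities

# A warning about rescaling a vector of mutation probabilities is expected here, per checks() above.
F, c, ll = ls.forwards(H, s, r, p_mutation=mu)
B = ls.backwards(H, s, c, r, p_mutation=mu)
path, ll_path = ls.viterbi(H, s, r, p_mutation=mu)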
14 changes: 7 additions & 7 deletions tests/test_API.py
@@ -80,7 +80,7 @@ def genotype_emission(self, mu, m):

def example_parameters_haplotypes(self, ts, seed=42, scale_mutation=True):
"""Returns an iterator over combinations of haplotype, recombination and
mutation rates."""
mutation probabilities."""
np.random.seed(seed)
H, haplotypes = self.example_haplotypes(ts)
n = H.shape[1]
@@ -240,8 +240,8 @@ def verify(self, ts):
for n, m, H_vs, s, e_vs, r, mu in self.example_parameters_haplotypes(ts):
F_vs, c_vs, ll_vs = fbh_vs.forwards_ls_hap(n, m, H_vs, s, e_vs, r)
B_vs = fbh_vs.backwards_ls_hap(n, m, H_vs, s, e_vs, c_vs, r)
F, c, ll = ls.forwards(H_vs, s, r, mutation_rate=mu)
B = ls.backwards(H_vs, s, c, r, mutation_rate=mu)
F, c, ll = ls.forwards(H_vs, s, r, p_mutation=mu)
B = ls.backwards(H_vs, s, c, r, p_mutation=mu)
self.assertAllClose(F, F_vs)
self.assertAllClose(B, B_vs)
self.assertAllClose(ll_vs, ll)
@@ -259,9 +259,9 @@ def verify(self, ts):
F_vs, c_vs, ll_vs = fbd_vs.forward_ls_dip_loop(
n, m, G_vs, s, e_vs, r, norm=True
)
F, c, ll = ls.forwards(G_vs, s, r, mutation_rate=mu)
F, c, ll = ls.forwards(G_vs, s, r, p_mutation=mu)
B_vs = fbd_vs.backward_ls_dip_loop(n, m, G_vs, s, e_vs, c_vs, r)
B = ls.backwards(G_vs, s, c, r, mutation_rate=mu)
B = ls.backwards(G_vs, s, c, r, p_mutation=mu)
self.assertAllClose(F, F_vs)
self.assertAllClose(B, B_vs)
self.assertAllClose(ll_vs, ll)
@@ -281,7 +281,7 @@ def verify(self, ts):
n, m, H_vs, s, e_vs, r
)
path_vs = vh_vs.backwards_viterbi_hap(m, V_vs, P_vs)
path, ll = ls.viterbi(H_vs, s, r, mutation_rate=mu)
path, ll = ls.viterbi(H_vs, s, r, p_mutation=mu)

self.assertAllClose(ll_vs, ll)
self.assertAllClose(path_vs, path)
@@ -298,7 +298,7 @@ def verify(self, ts):
)
path_vs = vd_vs.backwards_viterbi_dip(m, V_vs, P_vs)
phased_path_vs = vd_vs.get_phased_path(n, path_vs)
path, ll = ls.viterbi(G_vs, s, r, mutation_rate=mu)
path, ll = ls.viterbi(G_vs, s, r, p_mutation=mu)

self.assertAllClose(ll_vs, ll)
self.assertAllClose(phased_path_vs, path)