Skip to content

Commit

Permalink
Merge pull request #131 from szhan/refactor_scale_mutation_rate
Browse files Browse the repository at this point in the history
Refactor setting default of scale_mutation_rate
  • Loading branch information
szhan authored Jun 21, 2024
2 parents 3375578 + 5dcebd8 commit 8d474c5
Showing 1 changed file with 5 additions and 13 deletions.
18 changes: 5 additions & 13 deletions lshmm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def check_inputs(
:param numpy.ndarray num_alleles: Number of distinct alleles per site.
:param numpy.ndarray prob_recombination: Recombination probability.
:param numpy.ndarray prob_mutation: Mutation probability.
:param bool scale_mutation_rate: Scale mutation rate.
:param bool scale_mutation_rate: Scale mutation rate or not.
:return: Number of ref. haplotypes, number of sites, ploidy, emission prob. matrix.
:rtype: tuple
"""
Expand Down Expand Up @@ -113,6 +113,10 @@ def check_inputs(
)
raise ValueError(err_msg)

# Set whether to scale mutation rates if not set already.
if scale_mutation_rate is None:
scale_mutation_rate = True

# Check the mutation probability.
if prob_mutation is None:
warn_msg = "No mutation probability is passed; setting it as per Li & Stephens (2003) eqn. A2 and A3."
Expand Down Expand Up @@ -162,9 +166,6 @@ def forwards(
normalise=None,
):
"""Run the forwards algorithm on haploid or diploid genotype data."""
if scale_mutation_rate is None:
scale_mutation_rate = True

if normalise is None:
normalise = True

Expand Down Expand Up @@ -210,9 +211,6 @@ def backwards(
scale_mutation_rate=None,
):
"""Run the backwards algorithm on haploid or diploid genotype data."""
if scale_mutation_rate is None:
scale_mutation_rate = True

num_ref_haps, num_sites, ploidy, emission_matrix = check_inputs(
reference_panel=reference_panel,
query=query,
Expand Down Expand Up @@ -250,9 +248,6 @@ def viterbi(
scale_mutation_rate=None,
):
"""Run the Viterbi algorithm on haploid or diploid genotype data."""
if scale_mutation_rate is None:
scale_mutation_rate = True

num_ref_haps, num_sites, ploidy, emission_matrix = check_inputs(
reference_panel=reference_panel,
query=query,
Expand Down Expand Up @@ -298,9 +293,6 @@ def path_loglik(
scale_mutation_rate=None,
):
"""Evaluate the log-likelihood of a copying path for a query through a reference panel."""
if scale_mutation_rate is None:
scale_mutation_rate = True

num_ref_haps, num_sites, ploidy, emission_matrix = check_inputs(
reference_panel=reference_panel,
query=query,
Expand Down

0 comments on commit 8d474c5

Please sign in to comment.