From 5dcebd87befb5508819eac8e3e421716827e446d Mon Sep 17 00:00:00 2001 From: szhan Date: Fri, 21 Jun 2024 09:48:46 +0100 Subject: [PATCH] Refactor setting default of scale_mutation_rate --- lshmm/api.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/lshmm/api.py b/lshmm/api.py index 2e37c6b..e4642e2 100644 --- a/lshmm/api.py +++ b/lshmm/api.py @@ -55,7 +55,7 @@ def check_inputs( :param numpy.ndarray num_alleles: Number of distinct alleles per site. :param numpy.ndarray prob_recombination: Recombination probability. :param numpy.ndarray prob_mutation: Mutation probability. - :param bool scale_mutation_rate: Scale mutation rate. + :param bool scale_mutation_rate: Scale mutation rate or not. :return: Number of ref. haplotypes, number of sites, ploidy, emission prob. matrix. :rtype: tuple """ @@ -113,6 +113,10 @@ def check_inputs( ) raise ValueError(err_msg) + # Set whether to scale mutation rates if not set already. + if scale_mutation_rate is None: + scale_mutation_rate = True + # Check the mutation probability. if prob_mutation is None: warn_msg = "No mutation probability is passed; setting it as per Li & Stephens (2003) eqn. A2 and A3." @@ -162,9 +166,6 @@ def forwards( normalise=None, ): """Run the forwards algorithm on haploid or diploid genotype data.""" - if scale_mutation_rate is None: - scale_mutation_rate = True - if normalise is None: normalise = True @@ -210,9 +211,6 @@ def backwards( scale_mutation_rate=None, ): """Run the backwards algorithm on haploid or diploid genotype data.""" - if scale_mutation_rate is None: - scale_mutation_rate = True - num_ref_haps, num_sites, ploidy, emission_matrix = check_inputs( reference_panel=reference_panel, query=query, @@ -250,9 +248,6 @@ def viterbi( scale_mutation_rate=None, ): """Run the Viterbi algorithm on haploid or diploid genotype data.""" - if scale_mutation_rate is None: - scale_mutation_rate = True - num_ref_haps, num_sites, ploidy, emission_matrix = check_inputs( reference_panel=reference_panel, query=query, @@ -298,9 +293,6 @@ def path_loglik( scale_mutation_rate=None, ): """Evaluate the log-likelihood of a copying path for a query through a reference panel.""" - if scale_mutation_rate is None: - scale_mutation_rate = True - num_ref_haps, num_sites, ploidy, emission_matrix = check_inputs( reference_panel=reference_panel, query=query,