From 6c62be79cbf5cd18e57fcce921b5a81f00a9823e Mon Sep 17 00:00:00 2001 From: szhan Date: Fri, 21 Jun 2024 10:40:35 +0100 Subject: [PATCH] Rework check_inputs to no longer take num_alleles --- lshmm/api.py | 173 ++++++++++++++++------------ tests/test_api_fb_diploid.py | 16 +-- tests/test_api_fb_haploid.py | 7 +- tests/test_api_fb_haploid_multi.py | 8 +- tests/test_api_vit_diploid.py | 10 +- tests/test_api_vit_haploid.py | 6 +- tests/test_api_vit_haploid_multi.py | 6 +- 7 files changed, 129 insertions(+), 97 deletions(-) diff --git a/lshmm/api.py b/lshmm/api.py index e4642e2..ef76468 100644 --- a/lshmm/api.py +++ b/lshmm/api.py @@ -29,7 +29,7 @@ def check_inputs( reference_panel, query, - num_alleles, + ploidy, prob_recombination, prob_mutation, scale_mutation_rate, @@ -38,63 +38,66 @@ def check_inputs( Check that the input data and parameters are valid, and return data to run the HMM algorithms. - The reference panel must be an array of size (m, n) in the haploid case or - (m, n, n) in the diploid case, and the query must be an array of size (k, m), + The reference panel and query are arrays of size (m, n) and (k, m), respectively, where: m = number of sites. n = number of samples in the reference panel (haplotypes, not individuals). k = number of samples in the query (haplotypes, not individuals). - TODO: Support running on multiple queries. + TODO: Support running on multiple queries. Currently, only k = 1 or 2 is supported. The mutation rate can be scaled according to the set of alleles that can be mutated to based on the number of distinct alleles at each site. :param numpy.ndarray reference_panel: A panel of reference sequences. :param numpy.ndarray query: A query sequence. - :param numpy.ndarray num_alleles: Number of distinct alleles per site. + :param numpy.ndarray ploidy: Ploidy (only 1 or 2 are supported). :param numpy.ndarray prob_recombination: Recombination probability. :param numpy.ndarray prob_mutation: Mutation probability. :param bool scale_mutation_rate: Scale mutation rate or not. - :return: Number of ref. haplotypes, number of sites, ploidy, emission prob. matrix. + :return: Num. ref. hap., num. sites, checked ref. panel, checked query, emission prob. matrix. :rtype: tuple """ - # Check the reference panel. - if not len(reference_panel.shape) in (2, 3): - err_msg = "Reference panel array has incorrect number of dimensions." + # Check ploidy. + if not ploidy in [1, 2]: + err_msg = "Only ploidy levels 1 and 2 are supported." raise ValueError(err_msg) - if len(reference_panel.shape) == 2: - num_sites, num_ref_haps = reference_panel.shape - ploidy = 1 - else: - num_sites, num_ref_haps, _num_ref_haps = reference_panel.shape - if num_ref_haps != _num_ref_haps: - err_msg = "Reference panel array has incorrect dimensions." - raise ValueError(err_msg) - ploidy = 2 + # Check the reference panel. + if not len(reference_panel.shape) == 2: + err_msg = "Reference panel array has incorrect dimensions." + raise ValueError(err_msg) if np.any(reference_panel == core.MISSING): err_msg = "Reference panel cannot have any MISSING values." raise ValueError(err_msg) + if ploidy == 2: + if not np.all(np.isin(reference_panel, [0, 1, core.NONCOPY])): + err_msg = "Reference panel has illegal alleles. " + err_msg += "Only 0/1 encoding is supported in diploid mode." + raise ValueError(err_msg) + + num_sites, num_ref_haps = reference_panel.shape + # Check the queries. - if query.shape[1] != num_sites: - err_msg = "Number of sites in the query and reference panel do not match." + if query.shape[0] != ploidy: + err_msg = "Query array has incorrect dimensions." raise ValueError(err_msg) - if np.any(query == core.NONCOPY): - err_msg = "Queries cannot have any NONCOPY values." + if query.shape[1] != num_sites: + err_msg = "Number of sites in the query and reference panel don't match." raise ValueError(err_msg) - # Check the number of distinct alleles per site. - if len(num_alleles) != num_sites: - err_msg = "Number of alleles is not an array of expected length." + if np.any(query == core.NONCOPY): + err_msg = "Query cannot have any NONCOPY values." raise ValueError(err_msg) - if not np.all(num_alleles > 0) or not np.issubdtype(num_alleles.dtype, np.integer): - err_msg = "Number of alleles must be positive integers." - raise ValueError(err_msg) + if ploidy == 2: + if not np.all(np.isin(query, [0, 1, core.MISSING])): + err_msg = "Query has illegal alleles. " + err_msg += "Only 0/1 encoding is supported in diploid mode." + raise ValueError(err_msg) # Check the recombination probability. if isinstance(prob_recombination, (int, float)): @@ -137,6 +140,7 @@ def check_inputs( raise ValueError(err_msg) # Calculate the emission probability matrix. + num_alleles = core.get_num_alleles(reference_panel, query) if ploidy == 1: emission_matrix = core.get_emission_matrix_haploid( mu=prob_mutation, @@ -152,13 +156,32 @@ def check_inputs( scale_mutation_rate=scale_mutation_rate, ) - return num_ref_haps, num_sites, ploidy, emission_matrix + if ploidy == 1: + return ( + num_ref_haps, + num_sites, + reference_panel, + query, + emission_matrix, + ) + else: + ref_panel_genotypes = core.convert_haplotypes_to_phased_genotypes( + reference_panel + ) + query_genotypes = core.convert_haplotypes_to_unphased_genotypes(query) + return ( + num_ref_haps, + num_sites, + ref_panel_genotypes, + query_genotypes, + emission_matrix, + ) def forwards( reference_panel, query, - num_alleles, + ploidy, prob_recombination, *, prob_mutation=None, @@ -169,13 +192,15 @@ def forwards( if normalise is None: normalise = True - num_ref_haps, num_sites, ploidy, emission_matrix = check_inputs( - reference_panel=reference_panel, - query=query, - num_alleles=num_alleles, - prob_recombination=prob_recombination, - prob_mutation=prob_mutation, - scale_mutation_rate=scale_mutation_rate, + num_ref_haps, num_sites, ref_panel_checked, query_checked, emission_matrix = ( + check_inputs( + reference_panel=reference_panel, + query=query, + ploidy=ploidy, + prob_recombination=prob_recombination, + prob_mutation=prob_mutation, + scale_mutation_rate=scale_mutation_rate, + ) ) if ploidy == 1: @@ -190,8 +215,8 @@ def forwards( ) = forward_function( num_ref_haps, num_sites, - reference_panel, - query, + ref_panel_checked, + query_checked, emission_matrix, prob_recombination, norm=normalise, @@ -203,7 +228,7 @@ def forwards( def backwards( reference_panel, query, - num_alleles, + ploidy, normalisation_factor_from_forward, prob_recombination, *, @@ -211,13 +236,15 @@ def backwards( scale_mutation_rate=None, ): """Run the backwards algorithm on haploid or diploid genotype data.""" - num_ref_haps, num_sites, ploidy, emission_matrix = check_inputs( - reference_panel=reference_panel, - query=query, - num_alleles=num_alleles, - prob_recombination=prob_recombination, - prob_mutation=prob_mutation, - scale_mutation_rate=scale_mutation_rate, + num_ref_haps, num_sites, ref_panel_checked, query_checked, emission_matrix = ( + check_inputs( + reference_panel=reference_panel, + query=query, + ploidy=ploidy, + prob_recombination=prob_recombination, + prob_mutation=prob_mutation, + scale_mutation_rate=scale_mutation_rate, + ) ) if ploidy == 1: @@ -228,8 +255,8 @@ def backwards( backwards_array = backward_function( num_ref_haps, num_sites, - reference_panel, - query, + ref_panel_checked, + query_checked, emission_matrix, normalisation_factor_from_forward, prob_recombination, @@ -241,28 +268,30 @@ def backwards( def viterbi( reference_panel, query, - num_alleles, + ploidy, prob_recombination, *, prob_mutation=None, scale_mutation_rate=None, ): """Run the Viterbi algorithm on haploid or diploid genotype data.""" - num_ref_haps, num_sites, ploidy, emission_matrix = check_inputs( - reference_panel=reference_panel, - query=query, - num_alleles=num_alleles, - prob_recombination=prob_recombination, - prob_mutation=prob_mutation, - scale_mutation_rate=scale_mutation_rate, + num_ref_haps, num_sites, ref_panel_checked, query_checked, emission_matrix = ( + check_inputs( + reference_panel=reference_panel, + query=query, + ploidy=ploidy, + prob_recombination=prob_recombination, + prob_mutation=prob_mutation, + scale_mutation_rate=scale_mutation_rate, + ) ) if ploidy == 1: V, P, log_lik = forwards_viterbi_hap_lower_mem_rescaling( num_ref_haps, num_sites, - reference_panel, - query, + ref_panel_checked, + query_checked, emission_matrix, prob_recombination, ) @@ -271,8 +300,8 @@ def viterbi( V, P, log_lik = forwards_viterbi_dip_low_mem( num_ref_haps, num_sites, - reference_panel, - query, + ref_panel_checked, + query_checked, emission_matrix, prob_recombination, ) @@ -285,7 +314,7 @@ def viterbi( def path_loglik( reference_panel, query, - num_alleles, + ploidy, path, prob_recombination, *, @@ -293,13 +322,15 @@ def path_loglik( scale_mutation_rate=None, ): """Evaluate the log-likelihood of a copying path for a query through a reference panel.""" - num_ref_haps, num_sites, ploidy, emission_matrix = check_inputs( - reference_panel=reference_panel, - query=query, - num_alleles=num_alleles, - prob_recombination=prob_recombination, - prob_mutation=prob_mutation, - scale_mutation_rate=scale_mutation_rate, + num_ref_haps, num_sites, ref_panel_checked, query_checked, emission_matrix = ( + check_inputs( + reference_panel=reference_panel, + query=query, + ploidy=ploidy, + prob_recombination=prob_recombination, + prob_mutation=prob_mutation, + scale_mutation_rate=scale_mutation_rate, + ) ) if ploidy == 1: @@ -310,9 +341,9 @@ def path_loglik( log_lik = path_ll_function( num_ref_haps, num_sites, - reference_panel, + ref_panel_checked, path, - query, + query_checked, emission_matrix, prob_recombination, ) diff --git a/tests/test_api_fb_diploid.py b/tests/test_api_fb_diploid.py index a5b485f..a092beb 100644 --- a/tests/test_api_fb_diploid.py +++ b/tests/test_api_fb_diploid.py @@ -8,16 +8,16 @@ class TestForwardBackwardDiploid(lsbase.ForwardBackwardAlgorithmBase): def verify(self, ts, scale_mutation_rate, include_ancestors): + ploidy = 2 for n, m, H_vs, query, e_vs, r, mu in self.get_examples_pars( ts, - ploidy=2, + ploidy=ploidy, scale_mutation_rate=scale_mutation_rate, include_ancestors=include_ancestors, include_extreme_rates=True, ): G_vs = core.convert_haplotypes_to_phased_genotypes(H_vs) s = core.convert_haplotypes_to_unphased_genotypes(query) - num_alleles = core.get_num_alleles(H_vs, query) F_vs, c_vs, ll_vs = fbd.forward_ls_dip_loop( n=n, @@ -38,18 +38,18 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): r=r, ) F, c, ll = ls.forwards( - reference_panel=G_vs, - query=s, - num_alleles=num_alleles, + reference_panel=H_vs, + query=query, + ploidy=ploidy, prob_recombination=r, prob_mutation=mu, scale_mutation_rate=scale_mutation_rate, normalise=True, ) B = ls.backwards( - reference_panel=G_vs, - query=s, - num_alleles=num_alleles, + reference_panel=H_vs, + query=query, + ploidy=ploidy, normalisation_factor_from_forward=c, prob_recombination=r, prob_mutation=mu, diff --git a/tests/test_api_fb_haploid.py b/tests/test_api_fb_haploid.py index 3f5220c..058c575 100644 --- a/tests/test_api_fb_haploid.py +++ b/tests/test_api_fb_haploid.py @@ -8,9 +8,10 @@ class TestForwardBackwardHaploid(lsbase.ForwardBackwardAlgorithmBase): def verify(self, ts, scale_mutation_rate, include_ancestors): + ploidy = 1 for n, m, H_vs, s, e_vs, r, mu in self.get_examples_pars( ts, - ploidy=1, + ploidy=ploidy, scale_mutation_rate=scale_mutation_rate, include_ancestors=include_ancestors, include_extreme_rates=True, @@ -36,7 +37,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): F, c, ll = ls.forwards( reference_panel=H_vs, query=s, - num_alleles=num_alleles, + ploidy=ploidy, prob_recombination=r, prob_mutation=mu, scale_mutation_rate=scale_mutation_rate, @@ -45,7 +46,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): B = ls.backwards( reference_panel=H_vs, query=s, - num_alleles=num_alleles, + ploidy=ploidy, normalisation_factor_from_forward=c, prob_recombination=r, prob_mutation=mu, diff --git a/tests/test_api_fb_haploid_multi.py b/tests/test_api_fb_haploid_multi.py index a50b299..86ad692 100644 --- a/tests/test_api_fb_haploid_multi.py +++ b/tests/test_api_fb_haploid_multi.py @@ -8,14 +8,14 @@ class TestForwardBackwardHaploid(lsbase.ForwardBackwardAlgorithmBase): def verify(self, ts, scale_mutation_rate, include_ancestors): + ploidy = 1 for n, m, H_vs, s, e_vs, r, mu in self.get_examples_pars( ts, - ploidy=1, + ploidy=ploidy, scale_mutation_rate=scale_mutation_rate, include_ancestors=include_ancestors, include_extreme_rates=True, ): - num_alleles = core.get_num_alleles(H_vs, s) F_vs, c_vs, ll_vs = fbh.forwards_ls_hap( n=n, m=m, @@ -36,7 +36,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): F, c, ll = ls.forwards( reference_panel=H_vs, query=s, - num_alleles=num_alleles, + ploidy=ploidy, prob_recombination=r, prob_mutation=mu, scale_mutation_rate=scale_mutation_rate, @@ -45,7 +45,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): B = ls.backwards( reference_panel=H_vs, query=s, - num_alleles=num_alleles, + ploidy=ploidy, normalisation_factor_from_forward=c, prob_recombination=r, prob_mutation=mu, diff --git a/tests/test_api_vit_diploid.py b/tests/test_api_vit_diploid.py index 48d67b7..ea6a47c 100644 --- a/tests/test_api_vit_diploid.py +++ b/tests/test_api_vit_diploid.py @@ -8,16 +8,16 @@ class TestViterbiDiploid(lsbase.ViterbiAlgorithmBase): def verify(self, ts, scale_mutation_rate, include_ancestors): + ploidy = 2 for n, m, H_vs, query, e_vs, r, mu in self.get_examples_pars( ts, - ploidy=2, + ploidy=ploidy, scale_mutation_rate=scale_mutation_rate, include_ancestors=include_ancestors, include_extreme_rates=True, ): G_vs = core.convert_haplotypes_to_phased_genotypes(H_vs) s = core.convert_haplotypes_to_unphased_genotypes(query) - num_alleles = core.get_num_alleles(H_vs, query) V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_low_mem( n=n, @@ -30,9 +30,9 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): path_vs = vd.backwards_viterbi_dip(m=m, V_last=V_vs, P=P_vs) phased_path_vs = vd.get_phased_path(n=n, path=path_vs) path, ll = ls.viterbi( - reference_panel=G_vs, - query=s, - num_alleles=num_alleles, + reference_panel=H_vs, + query=query, + ploidy=ploidy, prob_recombination=r, prob_mutation=mu, scale_mutation_rate=scale_mutation_rate, diff --git a/tests/test_api_vit_haploid.py b/tests/test_api_vit_haploid.py index 3019091..1b737a6 100644 --- a/tests/test_api_vit_haploid.py +++ b/tests/test_api_vit_haploid.py @@ -8,14 +8,14 @@ class TestViterbiHaploid(lsbase.ViterbiAlgorithmBase): def verify(self, ts, scale_mutation_rate, include_ancestors): + ploidy = 1 for n, m, H_vs, s, e_vs, r, mu in self.get_examples_pars( ts, - ploidy=1, + ploidy=ploidy, scale_mutation_rate=scale_mutation_rate, include_ancestors=include_ancestors, include_extreme_rates=True, ): - num_alleles = core.get_num_alleles(H_vs, s) V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_lower_mem_rescaling( n=n, m=m, @@ -28,7 +28,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): path, ll = ls.viterbi( reference_panel=H_vs, query=s, - num_alleles=num_alleles, + ploidy=ploidy, prob_recombination=r, prob_mutation=mu, scale_mutation_rate=scale_mutation_rate, diff --git a/tests/test_api_vit_haploid_multi.py b/tests/test_api_vit_haploid_multi.py index 037dc0d..e93e853 100644 --- a/tests/test_api_vit_haploid_multi.py +++ b/tests/test_api_vit_haploid_multi.py @@ -8,14 +8,14 @@ class TestViterbiHaploid(lsbase.ViterbiAlgorithmBase): def verify(self, ts, scale_mutation_rate, include_ancestors): + ploidy = 1 for n, m, H_vs, s, e_vs, r, mu in self.get_examples_pars( ts, - ploidy=1, + ploidy=ploidy, scale_mutation_rate=scale_mutation_rate, include_ancestors=include_ancestors, include_extreme_rates=True, ): - num_alleles = core.get_num_alleles(H_vs, s) V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_lower_mem_rescaling( n=n, m=m, @@ -29,7 +29,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): path, ll = ls.viterbi( reference_panel=H_vs, query=s, - num_alleles=num_alleles, + ploidy=ploidy, prob_recombination=r, prob_mutation=mu, scale_mutation_rate=scale_mutation_rate,