diff --git a/python/tests/beagle.py b/python/tests/beagle.py index fa826cb48a..5427f1b0db 100644 --- a/python/tests/beagle.py +++ b/python/tests/beagle.py @@ -167,6 +167,12 @@ def compute_forward_probability_matrix(ref_h, query_h, rho, mu): :return: Forward probability matrix. :rtype: numpy.ndarray """ + assert np.all( + np.isin(ref_h, [0, 1]) + ), "Reference haplotypes have invalid values (not 0/1)." + assert np.all( + np.isin(query_h, [0, 1]) + ), "Query haplotype has invalid values (not 0/1)." m = ref_h.shape[0] h = ref_h.shape[1] assert m == len(query_h) @@ -220,6 +226,12 @@ def compute_backward_probability_matrix(ref_h, query_h, rho, mu): :return: Backward probability matrix. :rtype: numpy.ndarray """ + assert np.all( + np.isin(ref_h, [0, 1]) + ), "Reference haplotypes have invalid values (not 0/1)." + assert np.all( + np.isin(query_h, [0, 1]) + ), "Query haplotype has invalid values (not 0/1)." m = ref_h.shape[0] h = ref_h.shape[1] assert m == len(query_h) @@ -411,14 +423,19 @@ def compute_state_probability_matrix(fm, bm, ref_h, query_h): :return: HMM state probability matrix. :rtype: numpy.ndarray """ + assert np.all( + np.isin(ref_h, [0, 1]) + ), "Reference haplotypes have non-biallelic values." + assert np.all(np.isin(query_h, [0, 1])), "Query haplotype has non-biallelic values." m = ref_h.shape[0] h = ref_h.shape[1] assert m == len(query_h) - assert fm.shape == (m, h) - assert bm.shape == (m, h) - # Check all biallelic sites - assert np.all(np.isin(np.unique(ref_h), [0, 1])) - assert np.all(np.isin(np.unique(query_h), [-1, 0, 1])) + assert (m, h) == fm.shape + assert (m, h) == bm.shape + assert np.any(fm < 0), "Forward probability matrix has negative values." + assert np.any(np.isnan(fm)), "Forward probability matrix has NaN values." + assert np.any(bm < 0), "Backward probability matrix has negative values." + assert np.any(np.isnan(bm)), "Backward probability matrix has NaN values." sm = np.zeros((m, h), dtype=np.float64) for i in np.arange(m - 1, -1, -1): for j in np.arange(h):