diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c index da539552f5..d274ac8203 100644 --- a/c/tests/test_stats.c +++ b/c/tests/test_stats.c @@ -2045,10 +2045,10 @@ test_paper_ex_two_site(void) double truth_two_sets[18] = { 1, 1, 0.1111111111111111, 0.1111111111111111, 0.1111111111111111, 0.1111111111111111, 0.1111111111111111, 0.1111111111111111, 1, 1, 1, 1, 0.1111111111111111, 0.1111111111111111, 1, 1, 1, 1 }; - double truth_three_sets[27] - = { 1, 1, 0, 0.1111111111111111, 0.1111111111111111, 0, 0.1111111111111111, - 0.1111111111111111, 0, 0.1111111111111111, 0.1111111111111111, 0, 1, 1, 1, - 1, 1, 1, 0.1111111111111111, 0.1111111111111111, 0, 1, 1, 1, 1, 1, 1 }; + double truth_three_sets[27] = { 1, 1, NAN, 0.1111111111111111, 0.1111111111111111, + NAN, 0.1111111111111111, 0.1111111111111111, NAN, 0.1111111111111111, + 0.1111111111111111, NAN, 1, 1, 1, 1, 1, 1, 0.1111111111111111, + 0.1111111111111111, NAN, 1, 1, 1, 1, 1, 1 }; tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, paper_ex_mutations, paper_ex_individuals, NULL, 0); @@ -2104,7 +2104,8 @@ test_paper_ex_two_site(void) row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - assert_arrays_almost_equal(result_size * num_sample_sets, result, truth_three_sets); + assert_arrays_almost_equal_nan( + result_size * num_sample_sets, result, truth_three_sets); tsk_treeseq_free(&ts); tsk_safe_free(row_sites); diff --git a/c/tests/testlib.h b/c/tests/testlib.h index 4885fba356..df1dcdef4a 100644 --- a/c/tests/testlib.h +++ b/c/tests/testlib.h @@ -65,6 +65,22 @@ void unsort_edges(tsk_edge_table_t *edges, size_t start); } while (0); \ } +/* Array equality where NaN entries must line up; other entries must agree to 1e-9. + NB: the float cast for NaNs is for mingw, which complains without */ +#define assert_arrays_almost_equal_nan(len, a, b) \ + { \ + do { \ + tsk_size_t _j; \ + for (_j = 0; _j < (len); _j++) { \ + if (isnan((float) (a)[_j]) || isnan((float) (b)[_j])) { \ +
CU_ASSERT_FATAL(isnan((float) (a)[_j]) && isnan((float) (b)[_j])); \ + } else { \ + CU_ASSERT_DOUBLE_EQUAL((a)[_j], (b)[_j], 1e-9); \ + } \ + } \ + } while (0); \ + } + extern const char *single_tree_ex_nodes; extern const char *single_tree_ex_edges; extern const char *single_tree_ex_sites; diff --git a/c/tskit/trees.c b/c/tskit/trees.c index c9d337bda4..3b343db780 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -3597,11 +3597,7 @@ r2_summary_func(tsk_size_t state_dim, const double *state, double D = p_AB - (p_A * p_B); double denom = p_A * p_B * (1 - p_A) * (1 - p_B); - if (denom == 0 && D == 0) { - result[j] = 0; - } else { - result[j] = (D * D) / denom; - } + result[j] = (D * D) / denom; } return 0; } @@ -3637,8 +3633,8 @@ D_prime_summary_func(tsk_size_t state_dim, const double *state, double p_B = p_AB + p_aB; double D = p_AB - (p_A * p_B); - result[j] = 0; - if (D > 0) { + + if (D >= 0) { result[j] = D / TSK_MIN(p_A * (1 - p_B), (1 - p_A) * p_B); - } else if (D < 0) { + } else { result[j] = D / TSK_MIN(p_A * p_B, (1 - p_A) * (1 - p_B)); @@ -3681,11 +3677,7 @@ r_summary_func(tsk_size_t state_dim, const double *state, double D = p_AB - (p_A * p_B); double denom = p_A * p_B * (1 - p_A) * (1 - p_B); - if (denom == 0 && D == 0) { - result[j] = 0; - } else { - result[j] = D / sqrt(denom); - } + result[j] = D / sqrt(denom); } return 0; } diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index 13e2cdf1ea..58b52debb8 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -22,6 +22,7 @@ """ Test cases for two-locus statistics """ +import contextlib import io from itertools import combinations_with_replacement from itertools import permutations @@ -40,6 +41,12 @@ from tests.test_highlevel import get_example_tree_sequences +@contextlib.contextmanager +def suppress_division_by_zero_warning(): + with np.errstate(invalid="ignore", divide="ignore"): + yield + + class BitSet: """BitSet object, which stores values in arrays of
unsigned integers. The rows represent all possible values a bit can take, and the rows @@ -729,9 +736,7 @@ def r2_summary_func( D = p_AB - (p_A * p_B) denom = p_A * p_B * (1 - p_A) * (1 - p_B) - if denom == 0 and D == 0: - result[k] = 0 - else: + with suppress_division_by_zero_warning(): result[k] = (D * D) / denom @@ -782,12 +787,11 @@ def D_prime_summary_func( p_B = p_AB + p_aB D = p_AB - (p_A * p_B) - if D == 0: - result[k] = 0 - elif D > 0: - result[k] = D / min(p_A * (1 - p_B), (1 - p_A) * p_B) - else: - result[k] = D / min(p_A * p_B, (1 - p_A) * (1 - p_B)) + with suppress_division_by_zero_warning(): + if D >= 0: + result[k] = D / min(p_A * (1 - p_B), (1 - p_A) * p_B) + else: + result[k] = D / min(p_A * p_B, (1 - p_A) * (1 - p_B)) def r_summary_func( @@ -806,9 +810,7 @@ def r_summary_func( D = p_AB - (p_A * p_B) denom = p_A * p_B * (1 - p_A) * (1 - p_B) - if denom == 0 and D == 0: - result[k] = 0 - else: + with suppress_division_by_zero_warning(): result[k] = D / np.sqrt(denom)