From 8f5e416e89ae9c08816354c25e8421a668da27f1 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Thu, 13 Jul 2023 11:09:10 +0100 Subject: [PATCH] Improve error handling on weighted stats --- c/tests/test_stats.c | 84 +++++++++++++++++++++---- c/tskit/core.c | 4 ++ c/tskit/core.h | 4 ++ c/tskit/trees.c | 12 +++- python/tests/test_lowlevel.py | 6 +- python/tests/test_tree_stats.py | 107 +++++++++++++++++++++++++++----- python/tskit/trees.py | 5 +- 7 files changed, 188 insertions(+), 34 deletions(-) diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c index 154c0b6296..39f2a063e0 100644 --- a/c/tests/test_stats.c +++ b/c/tests/test_stats.c @@ -345,6 +345,16 @@ verify_window_errors(tsk_treeseq_t *ts, tsk_flags_t mode) ts, 1, W, 1, general_stat_error, NULL, 2, windows, options, sigma); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); + windows[0] = -1; + ret = tsk_treeseq_general_stat( + ts, 1, W, 1, general_stat_error, NULL, 2, windows, options, sigma); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); + + windows[1] = -1; + ret = tsk_treeseq_general_stat( + ts, 1, W, 1, general_stat_error, NULL, 1, windows, options, sigma); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); + windows[0] = 10; ret = tsk_treeseq_general_stat( ts, 1, W, 1, general_stat_error, NULL, 2, windows, options, sigma); @@ -438,11 +448,10 @@ verify_node_general_stat_errors(tsk_treeseq_t *ts) static void verify_one_way_weighted_func_errors(tsk_treeseq_t *ts, one_way_weighted_method *method) { - // we don't have any specific errors for this function - // but we might add some in the future int ret; tsk_size_t num_samples = tsk_treeseq_get_num_samples(ts); double *weights = tsk_malloc(num_samples * sizeof(double)); + double bad_windows[] = { 0, -1 }; double result; tsk_size_t j; @@ -451,7 +460,10 @@ verify_one_way_weighted_func_errors(tsk_treeseq_t *ts, one_way_weighted_method * } ret = method(ts, 0, weights, 0, NULL, 0, &result); - CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_STATE_DIMS); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INSUFFICIENT_WEIGHTS); + + ret = method(ts, 1, weights, 1, bad_windows, 0, &result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); free(weights); } @@ -460,12 +472,11 @@ static void verify_one_way_weighted_covariate_func_errors( tsk_treeseq_t *ts, one_way_covariates_method *method) { - // we don't have any specific errors for this function - // but we might add some in the future int ret; tsk_size_t num_samples = tsk_treeseq_get_num_samples(ts); double *weights = tsk_malloc(num_samples * sizeof(double)); double *covariates = NULL; + double bad_windows[] = { 0, -1 }; double result; tsk_size_t j; @@ -474,7 +485,10 @@ verify_one_way_weighted_covariate_func_errors( } ret = method(ts, 0, weights, 0, covariates, 0, NULL, 0, &result); - CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_STATE_DIMS); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INSUFFICIENT_WEIGHTS); + + ret = method(ts, 1, weights, 0, covariates, 1, bad_windows, 0, &result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); free(weights); } @@ -558,6 +572,28 @@ verify_two_way_stat_func_errors(tsk_treeseq_t *ts, general_sample_stat_method *m CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_SAMPLE_SET_INDEX); } +static void +verify_two_way_weighted_stat_func_errors( + tsk_treeseq_t *ts, two_way_weighted_method *method) +{ + int ret; + tsk_id_t indexes[] = { 0, 0, 0, 1 }; + double bad_windows[] = { -1, -1 }; + double weights[10]; + double result[10]; + + memset(weights, 0, sizeof(weights)); + + ret = method(ts, 2, weights, 2, indexes, 0, NULL, result, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = method(ts, 0, weights, 2, indexes, 0, NULL, result, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INSUFFICIENT_WEIGHTS); + + ret = method(ts, 2, weights, 2, indexes, 1, bad_windows, result, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); +} + static void verify_three_way_stat_func_errors(tsk_treeseq_t *ts, general_sample_stat_method *method) { @@ -1504,32 +1540,54 @@ test_paper_ex_genetic_relatedness(void) tsk_treeseq_free(&ts); } +static void +test_paper_ex_genetic_relatedness_errors(void) +{ + tsk_treeseq_t ts; + + tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, + paper_ex_mutations, paper_ex_individuals, NULL, 0); + verify_two_way_stat_func_errors(&ts, tsk_treeseq_genetic_relatedness); + tsk_treeseq_free(&ts); +} + static void test_paper_ex_genetic_relatedness_weighted(void) { tsk_treeseq_t ts; double weights[] = { 1.2, 0.1, 0.0, 0.0, 3.4, 5.0, 1.0, -1.0 }; tsk_id_t indexes[] = { 0, 0, 0, 1 }; - double result[2]; + double result[100]; + tsk_size_t num_weights; int ret; tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, paper_ex_mutations, paper_ex_individuals, NULL, 0); - ret = tsk_treeseq_genetic_relatedness_weighted( - &ts, 2, weights, 2, indexes, 0, NULL, result, TSK_STAT_SITE); - CU_ASSERT_EQUAL_FATAL(ret, 0); + for (num_weights = 1; num_weights < 3; num_weights++) { + ret = tsk_treeseq_genetic_relatedness_weighted( + &ts, num_weights, weights, 2, indexes, 0, NULL, result, TSK_STAT_SITE); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_genetic_relatedness_weighted( + &ts, num_weights, weights, 2, indexes, 0, NULL, result, TSK_STAT_BRANCH); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_genetic_relatedness_weighted( + &ts, num_weights, weights, 2, indexes, 0, NULL, result, TSK_STAT_NODE); + CU_ASSERT_EQUAL_FATAL(ret, 0); + } + tsk_treeseq_free(&ts); } static void -test_paper_ex_genetic_relatedness_errors(void) +test_paper_ex_genetic_relatedness_weighted_errors(void) { tsk_treeseq_t ts; tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, paper_ex_mutations, paper_ex_individuals, NULL, 0); - verify_two_way_stat_func_errors(&ts, tsk_treeseq_genetic_relatedness); + verify_two_way_weighted_stat_func_errors( + &ts, tsk_treeseq_genetic_relatedness_weighted); tsk_treeseq_free(&ts); } @@ -2128,6 +2186,8 @@ main(int argc, char **argv) { "test_paper_ex_genetic_relatedness", test_paper_ex_genetic_relatedness }, { "test_paper_ex_genetic_relatedness_weighted", test_paper_ex_genetic_relatedness_weighted }, + { "test_paper_ex_genetic_relatedness_weighted_errors", + test_paper_ex_genetic_relatedness_weighted_errors }, { "test_paper_ex_Y2_errors", test_paper_ex_Y2_errors }, { "test_paper_ex_Y2", test_paper_ex_Y2 }, { "test_paper_ex_f2_errors", test_paper_ex_f2_errors }, diff --git a/c/tskit/core.c b/c/tskit/core.c index 100cc78cad..5a8ed6d9ac 100644 --- a/c/tskit/core.c +++ b/c/tskit/core.c @@ -475,6 +475,10 @@ tsk_strerror_internal(int err) "statistic. " "(TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED)"; break; + case TSK_ERR_INSUFFICIENT_WEIGHTS: + ret = "Insufficient weights provided (at least 1 required). " + "(TSK_ERR_INSUFFICIENT_WEIGHTS)"; + break; /* Mutation mapping errors */ case TSK_ERR_GENOTYPES_ALL_MISSING: diff --git a/c/tskit/core.h b/c/tskit/core.h index 4d2c95212d..45a33dd8b7 100644 --- a/c/tskit/core.h +++ b/c/tskit/core.h @@ -685,6 +685,10 @@ The TSK_STAT_SPAN_NORMALISE option was passed to a statistic that does not support it. */ #define TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED -912 +/** +Insufficient weights were provided. +*/ +#define TSK_ERR_INSUFFICIENT_WEIGHTS -913 /** @} */ /** diff --git a/c/tskit/trees.c b/c/tskit/trees.c index b5cb654a7d..dac3ac154b 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -2639,6 +2639,10 @@ tsk_treeseq_trait_covariance(const tsk_treeseq_t *self, tsk_size_t num_weights, ret = TSK_ERR_NO_MEMORY; goto out; } + if (num_weights == 0) { + ret = TSK_ERR_INSUFFICIENT_WEIGHTS; + goto out; + } // center weights for (j = 0; j < num_samples; j++) { @@ -2710,7 +2714,7 @@ tsk_treeseq_trait_correlation(const tsk_treeseq_t *self, tsk_size_t num_weights, } if (num_weights < 1) { - ret = TSK_ERR_BAD_STATE_DIMS; + ret = TSK_ERR_INSUFFICIENT_WEIGHTS; goto out; } @@ -2823,7 +2827,7 @@ tsk_treeseq_trait_linear_model(const tsk_treeseq_t *self, tsk_size_t num_weights } if (num_weights < 1) { - ret = TSK_ERR_BAD_STATE_DIMS; + ret = TSK_ERR_INSUFFICIENT_WEIGHTS; goto out; } @@ -3071,6 +3075,10 @@ tsk_treeseq_genetic_relatedness_weighted(const tsk_treeseq_t *self, ret = TSK_ERR_NO_MEMORY; goto out; } + if (num_weights == 0) { + ret = TSK_ERR_INSUFFICIENT_WEIGHTS; + goto out; + } // Add a column of ones to W for (j = 0; j < num_samples; j++) { diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index d94d6e9784..c33f159deb 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -2153,8 +2153,12 @@ def test_bad_weights(self): del params["weights"] n = ts.get_num_samples() + for bad_weight_type in [None, [None, None]]: + with pytest.raises(ValueError, match="object of too small depth"): + f(weights=bad_weight_type, **params) + for bad_weight_shape in [(n - 1, 1), (n + 1, 1), (0, 3)]: - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="First dimension must be num_samples"): f(weights=np.ones(bad_weight_shape), **params) def test_output_dims(self): diff --git a/python/tests/test_tree_stats.py b/python/tests/test_tree_stats.py index a06a690483..99e8e11c55 100644 --- a/python/tests/test_tree_stats.py +++ b/python/tests/test_tree_stats.py @@ -453,7 +453,6 @@ def node_summary(u): # contains the location of the last time we updated the output for a node. last_update = np.zeros((ts.num_nodes, 1)) for (t_left, t_right), edges_out, edges_in in ts.edge_diffs(): - for edge in edges_out: u = edge.child v = edge.parent @@ -980,7 +979,6 @@ def verify(self, ts): self.verify_weighted_stat(ts, W, windows=windows) def verify_definition(self, ts, W, windows, summary_func, ts_method, definition): - # general_stat will need an extra column for p gW = self.transform_weights(W) @@ -1025,7 +1023,6 @@ def verify(self, ts): def verify_definition( self, ts, sample_sets, windows, summary_func, ts_method, definition ): - W = np.array([[u in A for A in sample_sets] for u in ts.samples()], dtype=float) def wrapped_summary_func(x): @@ -1762,7 +1759,6 @@ def divergence( class TestDivergence(StatsTestCase, TwoWaySampleSetStatsMixin): - # Derived classes define this to get a specific stats mode. mode = None @@ -1974,7 +1970,6 @@ def genetic_relatedness( class TestGeneticRelatedness(StatsTestCase, TwoWaySampleSetStatsMixin): - # Derived classes define this to get a specific stats mode. mode = None @@ -2035,7 +2030,6 @@ def wrapped_summary_func(x): self.assertArrayAlmostEqual(sigma1, sigma4) def verify_sample_sets_indexes(self, ts, sample_sets, indexes, windows): - n = np.array([len(x) for x in sample_sets]) n_total = sum(n) @@ -2209,14 +2203,12 @@ def example_index_pairs(weights): class TestGeneticRelatednessWeighted(StatsTestCase, WeightStatsMixin): - # Derived classes define this to get a specific stats mode. mode = None def verify_definition( self, ts, W, indexes, windows, summary_func, ts_method, definition ): - # Determine output_dim of the function M = len(indexes) @@ -2298,6 +2290,96 @@ class TestSiteGeneticRelatednessWeighted( mode = "site" +# NOTE: these classes don't follow the same (anti)-patterns as used elsewhere as they +# were added in several years afterwards. + + +class TestGeneticRelatednessWeightedSimpleExamples: + # Values verified against the simple implementations above + site_value = 11.12 + branch_value = 14.72 + + def fixture(self): + ts = tskit.Tree.generate_balanced(5).tree_sequence + # Abitrary weights that give non-zero results + W = np.zeros((ts.num_samples, 2)) + W[0, :] = 1 + W[1, :] = 2 + return tsutil.insert_branch_sites(ts), W + + def test_no_arguments_site(self): + ts, W = self.fixture() + X = ts.genetic_relatedness_weighted(W, mode="site") + assert X.shape == tuple() + nt.assert_almost_equal(X, self.site_value) + + def test_windows_site(self): + ts, W = self.fixture() + X = ts.genetic_relatedness_weighted(W, mode="site", windows=[0, 1 - 1e-12, 1]) + assert X.shape == (2,) + nt.assert_almost_equal(X[0], self.site_value) + nt.assert_almost_equal(X[1], 0) + + def test_no_arguments_branch(self): + ts, W = self.fixture() + X = ts.genetic_relatedness_weighted(W, mode="branch") + assert X.shape == tuple() + nt.assert_almost_equal(X, self.branch_value) + + def test_windows_branch(self): + ts, W = self.fixture() + X = ts.genetic_relatedness_weighted(W, mode="branch", windows=[0, 0.5, 1]) + assert X.shape == (2,) + nt.assert_almost_equal(X, self.branch_value) + + def test_indexes_1D(self): + ts, W = self.fixture() + indexes = [0, 1] + X = ts.genetic_relatedness_weighted(W, indexes, mode="branch") + assert X.shape == tuple() + nt.assert_almost_equal(X, self.branch_value) + + def test_indexes_2D(self): + ts, W = self.fixture() + indexes = [[0, 1]] + X = ts.genetic_relatedness_weighted(W, indexes, mode="branch") + assert X.shape == (1,) + nt.assert_almost_equal(X, self.branch_value) + + def test_indexes_2D_windows(self): + ts, W = self.fixture() + indexes = [[0, 1], [0, 1]] + X = ts.genetic_relatedness_weighted( + W, indexes, windows=[0, 0.5, 1], mode="branch" + ) + assert X.shape == (2, 2) + nt.assert_almost_equal(X, self.branch_value) + + +class TestGeneticRelatednessWeightedErrors: + def ts(self): + return tskit.Tree.generate_balanced(3).tree_sequence + + @pytest.mark.parametrize("W", [[0], np.array([0]), np.zeros(100)]) + def test_bad_weight_size(self, W): + with pytest.raises(ValueError, match="First trait dimension"): + self.ts().genetic_relatedness_weighted(W) + + @pytest.mark.parametrize("cols", [1, 3]) + def test_no_indexes_with_non_2_cols(self, cols): + ts = self.ts() + W = np.zeros((ts.num_samples, cols)) + with pytest.raises(ValueError, match="Must specify indexes"): + ts.genetic_relatedness_weighted(W) + + @pytest.mark.parametrize("indexes", [[], [[0]], [[0, 0, 0]], [[[0], [0], [0]]]]) + def test_bad_index_shapes(self, indexes): + ts = self.ts() + W = np.zeros((ts.num_samples, 2)) + with pytest.raises(ValueError, match="Indexes must be convertable to a 2D"): + ts.genetic_relatedness_weighted(W, indexes=indexes) + + ############################################ # Fst ############################################ @@ -2340,7 +2422,6 @@ def single_site_Fst(ts, sample_sets, indexes): class TestFst(StatsTestCase, TwoWaySampleSetStatsMixin): - # Derived classes define this to get a specific stats mode. mode = None @@ -2529,7 +2610,6 @@ def Y2(ts, sample_sets, indexes=None, windows=None, mode="site", span_normalise= class TestY2(StatsTestCase, TwoWaySampleSetStatsMixin): - # Derived classes define this to get a specific stats mode. mode = None @@ -2702,7 +2782,6 @@ def Y3(ts, sample_sets, indexes=None, windows=None, mode="site", span_normalise= class TestY3(StatsTestCase, ThreeWaySampleSetStatsMixin): - # Derived classes define this to get a specific stats mode. mode = None @@ -2871,7 +2950,6 @@ def f2(ts, sample_sets, indexes=None, windows=None, mode="site", span_normalise= class Testf2(StatsTestCase, TwoWaySampleSetStatsMixin): - # Derived classes define this to get a specific stats mode. mode = None @@ -3057,7 +3135,6 @@ def f3(ts, sample_sets, indexes=None, windows=None, mode="site", span_normalise= class Testf3(StatsTestCase, ThreeWaySampleSetStatsMixin): - # Derived classes define this to get a specific stats mode. mode = None @@ -3248,7 +3325,6 @@ def f4(ts, sample_sets, indexes=None, windows=None, mode="site", span_normalise= class Testf4(StatsTestCase, FourWaySampleSetStatsMixin): - # Derived classes define this to get a specific stats mode. mode = None @@ -3512,7 +3588,6 @@ def update_result(window_index, u, right): last_update[u] = right for (t_left, t_right), edges_out, edges_in in ts.edge_diffs(): - for edge in edges_out: u = edge.child v = edge.parent @@ -3673,7 +3748,6 @@ def allele_frequency_spectrum( class TestAlleleFrequencySpectrum(StatsTestCase, SampleSetStatsMixin): - # Derived classes define this to get a specific stats mode. mode = None @@ -6003,7 +6077,6 @@ def f(x): branch_true_diversity_02, ], ): - self.assertAlmostEqual(diversity(ts, A, mode=mode)[0][0], truth) self.assertAlmostEqual(ts.sample_count_stat(A, f, 1, mode=mode)[0], truth) self.assertAlmostEqual(ts.diversity(A, mode="branch")[0], truth) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 236a45b8b7..10b3bd9782 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -7569,10 +7569,11 @@ def __k_way_weighted_stat( span_normalise=True, polarised=False, ): + W = np.asarray(W) if indexes is None: if W.shape[1] != k: raise ValueError( - "Must specify indexes if there are not exactly {} columsn " + "Must specify indexes if there are not exactly {} columns " "in W.".format(k) ) indexes = np.arange(k, dtype=np.int32) @@ -8016,7 +8017,7 @@ def genetic_relatedness_weighted( window (defaults to True). :return: A ndarray with shape equal to (num windows, num statistics). """ - if W.shape[0] != self.num_samples: + if len(W) != self.num_samples: raise ValueError( "First trait dimension must be equal to number of samples." )