Skip to content

Commit

Permalink
Improve error handling on weighted stats
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher authored and mergify[bot] committed Jul 13, 2023
1 parent e20a0a2 commit 60d75c6
Show file tree
Hide file tree
Showing 7 changed files with 188 additions and 34 deletions.
84 changes: 72 additions & 12 deletions c/tests/test_stats.c
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,16 @@ verify_window_errors(tsk_treeseq_t *ts, tsk_flags_t mode)
ts, 1, W, 1, general_stat_error, NULL, 2, windows, options, sigma);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS);

windows[0] = -1;
ret = tsk_treeseq_general_stat(
ts, 1, W, 1, general_stat_error, NULL, 2, windows, options, sigma);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS);

windows[1] = -1;
ret = tsk_treeseq_general_stat(
ts, 1, W, 1, general_stat_error, NULL, 1, windows, options, sigma);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS);

windows[0] = 10;
ret = tsk_treeseq_general_stat(
ts, 1, W, 1, general_stat_error, NULL, 2, windows, options, sigma);
Expand Down Expand Up @@ -438,11 +448,10 @@ verify_node_general_stat_errors(tsk_treeseq_t *ts)
static void
verify_one_way_weighted_func_errors(tsk_treeseq_t *ts, one_way_weighted_method *method)
{
// we don't have any specific errors for this function
// but we might add some in the future
int ret;
tsk_size_t num_samples = tsk_treeseq_get_num_samples(ts);
double *weights = tsk_malloc(num_samples * sizeof(double));
double bad_windows[] = { 0, -1 };
double result;
tsk_size_t j;

Expand All @@ -451,7 +460,10 @@ verify_one_way_weighted_func_errors(tsk_treeseq_t *ts, one_way_weighted_method *
}

ret = method(ts, 0, weights, 0, NULL, 0, &result);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_STATE_DIMS);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INSUFFICIENT_WEIGHTS);

ret = method(ts, 1, weights, 1, bad_windows, 0, &result);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS);

free(weights);
}
Expand All @@ -460,12 +472,11 @@ static void
verify_one_way_weighted_covariate_func_errors(
tsk_treeseq_t *ts, one_way_covariates_method *method)
{
// we don't have any specific errors for this function
// but we might add some in the future
int ret;
tsk_size_t num_samples = tsk_treeseq_get_num_samples(ts);
double *weights = tsk_malloc(num_samples * sizeof(double));
double *covariates = NULL;
double bad_windows[] = { 0, -1 };
double result;
tsk_size_t j;

Expand All @@ -474,7 +485,10 @@ verify_one_way_weighted_covariate_func_errors(
}

ret = method(ts, 0, weights, 0, covariates, 0, NULL, 0, &result);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_STATE_DIMS);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INSUFFICIENT_WEIGHTS);

ret = method(ts, 1, weights, 0, covariates, 1, bad_windows, 0, &result);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS);

free(weights);
}
Expand Down Expand Up @@ -558,6 +572,28 @@ verify_two_way_stat_func_errors(tsk_treeseq_t *ts, general_sample_stat_method *m
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_SAMPLE_SET_INDEX);
}

static void
verify_two_way_weighted_stat_func_errors(
tsk_treeseq_t *ts, two_way_weighted_method *method)
{
int ret;
tsk_id_t indexes[] = { 0, 0, 0, 1 };
double bad_windows[] = { -1, -1 };
double weights[10];
double result[10];

memset(weights, 0, sizeof(weights));

ret = method(ts, 2, weights, 2, indexes, 0, NULL, result, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);

ret = method(ts, 0, weights, 2, indexes, 0, NULL, result, 0);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INSUFFICIENT_WEIGHTS);

ret = method(ts, 2, weights, 2, indexes, 1, bad_windows, result, 0);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS);
}

static void
verify_three_way_stat_func_errors(tsk_treeseq_t *ts, general_sample_stat_method *method)
{
Expand Down Expand Up @@ -1504,32 +1540,54 @@ test_paper_ex_genetic_relatedness(void)
tsk_treeseq_free(&ts);
}

static void
test_paper_ex_genetic_relatedness_errors(void)
{
tsk_treeseq_t ts;

tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites,
paper_ex_mutations, paper_ex_individuals, NULL, 0);
verify_two_way_stat_func_errors(&ts, tsk_treeseq_genetic_relatedness);
tsk_treeseq_free(&ts);
}

static void
test_paper_ex_genetic_relatedness_weighted(void)
{
tsk_treeseq_t ts;
double weights[] = { 1.2, 0.1, 0.0, 0.0, 3.4, 5.0, 1.0, -1.0 };
tsk_id_t indexes[] = { 0, 0, 0, 1 };
double result[2];
double result[100];
tsk_size_t num_weights;
int ret;

tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites,
paper_ex_mutations, paper_ex_individuals, NULL, 0);

ret = tsk_treeseq_genetic_relatedness_weighted(
&ts, 2, weights, 2, indexes, 0, NULL, result, TSK_STAT_SITE);
CU_ASSERT_EQUAL_FATAL(ret, 0);
for (num_weights = 1; num_weights < 3; num_weights++) {
ret = tsk_treeseq_genetic_relatedness_weighted(
&ts, num_weights, weights, 2, indexes, 0, NULL, result, TSK_STAT_SITE);
CU_ASSERT_EQUAL_FATAL(ret, 0);
ret = tsk_treeseq_genetic_relatedness_weighted(
&ts, num_weights, weights, 2, indexes, 0, NULL, result, TSK_STAT_BRANCH);
CU_ASSERT_EQUAL_FATAL(ret, 0);
ret = tsk_treeseq_genetic_relatedness_weighted(
&ts, num_weights, weights, 2, indexes, 0, NULL, result, TSK_STAT_NODE);
CU_ASSERT_EQUAL_FATAL(ret, 0);
}

tsk_treeseq_free(&ts);
}

static void
test_paper_ex_genetic_relatedness_errors(void)
test_paper_ex_genetic_relatedness_weighted_errors(void)
{
tsk_treeseq_t ts;

tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites,
paper_ex_mutations, paper_ex_individuals, NULL, 0);
verify_two_way_stat_func_errors(&ts, tsk_treeseq_genetic_relatedness);
verify_two_way_weighted_stat_func_errors(
&ts, tsk_treeseq_genetic_relatedness_weighted);
tsk_treeseq_free(&ts);
}

Expand Down Expand Up @@ -2128,6 +2186,8 @@ main(int argc, char **argv)
{ "test_paper_ex_genetic_relatedness", test_paper_ex_genetic_relatedness },
{ "test_paper_ex_genetic_relatedness_weighted",
test_paper_ex_genetic_relatedness_weighted },
{ "test_paper_ex_genetic_relatedness_weighted_errors",
test_paper_ex_genetic_relatedness_weighted_errors },
{ "test_paper_ex_Y2_errors", test_paper_ex_Y2_errors },
{ "test_paper_ex_Y2", test_paper_ex_Y2 },
{ "test_paper_ex_f2_errors", test_paper_ex_f2_errors },
Expand Down
4 changes: 4 additions & 0 deletions c/tskit/core.c
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,10 @@ tsk_strerror_internal(int err)
"statistic. "
"(TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED)";
break;
case TSK_ERR_INSUFFICIENT_WEIGHTS:
ret = "Insufficient weights provided (at least 1 required). "
"(TSK_ERR_INSUFFICIENT_WEIGHTS)";
break;

/* Mutation mapping errors */
case TSK_ERR_GENOTYPES_ALL_MISSING:
Expand Down
4 changes: 4 additions & 0 deletions c/tskit/core.h
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,10 @@ The TSK_STAT_SPAN_NORMALISE option was passed to a statistic that does
not support it.
*/
#define TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED -912
/**
Insufficient weights were provided.
*/
#define TSK_ERR_INSUFFICIENT_WEIGHTS -913
/** @} */

/**
Expand Down
12 changes: 10 additions & 2 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -2639,6 +2639,10 @@ tsk_treeseq_trait_covariance(const tsk_treeseq_t *self, tsk_size_t num_weights,
ret = TSK_ERR_NO_MEMORY;
goto out;
}
if (num_weights == 0) {
ret = TSK_ERR_INSUFFICIENT_WEIGHTS;
goto out;
}

// center weights
for (j = 0; j < num_samples; j++) {
Expand Down Expand Up @@ -2710,7 +2714,7 @@ tsk_treeseq_trait_correlation(const tsk_treeseq_t *self, tsk_size_t num_weights,
}

if (num_weights < 1) {
ret = TSK_ERR_BAD_STATE_DIMS;
ret = TSK_ERR_INSUFFICIENT_WEIGHTS;
goto out;
}

Expand Down Expand Up @@ -2823,7 +2827,7 @@ tsk_treeseq_trait_linear_model(const tsk_treeseq_t *self, tsk_size_t num_weights
}

if (num_weights < 1) {
ret = TSK_ERR_BAD_STATE_DIMS;
ret = TSK_ERR_INSUFFICIENT_WEIGHTS;
goto out;
}

Expand Down Expand Up @@ -3071,6 +3075,10 @@ tsk_treeseq_genetic_relatedness_weighted(const tsk_treeseq_t *self,
ret = TSK_ERR_NO_MEMORY;
goto out;
}
if (num_weights == 0) {
ret = TSK_ERR_INSUFFICIENT_WEIGHTS;
goto out;
}

// Add a column of ones to W
for (j = 0; j < num_samples; j++) {
Expand Down
6 changes: 5 additions & 1 deletion python/tests/test_lowlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2153,8 +2153,12 @@ def test_bad_weights(self):
del params["weights"]
n = ts.get_num_samples()

for bad_weight_type in [None, [None, None]]:
with pytest.raises(ValueError, match="object of too small depth"):
f(weights=bad_weight_type, **params)

for bad_weight_shape in [(n - 1, 1), (n + 1, 1), (0, 3)]:
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="First dimension must be num_samples"):
f(weights=np.ones(bad_weight_shape), **params)

def test_output_dims(self):
Expand Down
Loading

0 comments on commit 60d75c6

Please sign in to comment.