Skip to content

Commit

Permalink
allow more than one column of weights
Browse files (browse the repository at this point in the history)
  • Loading branch information
petrelharp committed Sep 19, 2024
1 parent 86b87d7 commit bb5b603
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 11 deletions.
15 changes: 12 additions & 3 deletions c/tests/test_stats.c
Original file line number Diff line number Diff line change
Expand Up @@ -2000,18 +2000,27 @@ test_paper_ex_genetic_relatedness_vector(void)
tsk_size_t num_samples;
double *weights, *result;
tsk_size_t j;
tsk_size_t num_weights = 2;

tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites,
paper_ex_mutations, paper_ex_individuals, NULL, 0);
num_samples = tsk_treeseq_get_num_samples(&ts);

weights = tsk_malloc(num_samples * sizeof(double));
result = tsk_malloc(num_samples * sizeof(double));
weights = tsk_malloc(num_weights * num_samples * sizeof(double));
result = tsk_malloc(num_weights * num_samples * sizeof(double));
for (j = 0; j < num_samples; j++) {
weights[j] = 1.0;
}
for (j = 0; j < num_samples; j++) {
weights[j + num_samples] = (float) j;
}

ret = tsk_treeseq_genetic_relatedness_vector(
&ts, num_weights, weights, 0, NULL, result, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);

ret = tsk_treeseq_genetic_relatedness_vector(&ts, 1, weights, 0, NULL, result, 0);
ret = tsk_treeseq_genetic_relatedness_vector(
&ts, num_weights, weights, 0, NULL, result, TSK_STAT_NONCENTRED);
CU_ASSERT_EQUAL_FATAL(ret, 0);

tsk_treeseq_free(&ts);
Expand Down
9 changes: 6 additions & 3 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -9976,14 +9976,16 @@ tsk_matvec_calculator_init(tsk_matvec_calculator_t *self, const tsk_treeseq_t *t
goto out;

Check warning on line 9976 in c/tskit/trees.c

View check run for this annotation

Codecov / codecov/patch

c/tskit/trees.c#L9975-L9976

Added lines #L9975 - L9976 were not covered by tests
}

tsk_memset(result, 0, num_samples * sizeof(*result));
tsk_memset(result, 0, num_samples * num_weights * sizeof(*result));
tsk_memset(self->parent, TSK_NULL, num_nodes * sizeof(*self->parent));

// TODO: if centred then mean-centre the sample weights

for (j = 0; j < (tsk_id_t) num_samples; j++) {
u = ts->samples[j];
row = GET_2D_ROW(weights, num_weights, j);
new_row = GET_2D_ROW(self->w, num_weights, u);
for (k = 0; k < num_weights; j++) {
for (k = 0; k < num_weights; k++) {
new_row[k] = row[k];
}
// add branch to the virtual sample
Expand Down Expand Up @@ -10093,11 +10095,12 @@ tsk_matvec_calculator_write_output(tsk_matvec_calculator_t *self)
for (j = 0; j < n; j++) {
v = self->virtual_root + 1 + (tsk_id_t) j;
v_row = GET_2D_ROW(self->v, self->num_weights, v);
out_row = GET_2D_ROW(y, self->num_weights, v);
out_row = GET_2D_ROW(y, self->num_weights, j);
for (k = 0; k < self->num_weights; k++) {
out_row[k] = v_row[k];
}
}
// TODO: if centred then mean-centre the output
}

static int
Expand Down
4 changes: 2 additions & 2 deletions python/_tskitmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -9697,8 +9697,8 @@ TreeSequence_weighted_stat_vector_method(
goto out;
}

result_array
= TreeSequence_allocate_results_array(self, options, num_windows, num_samples);
result_array = TreeSequence_allocate_results_array(
self, options, num_windows, num_samples * w_shape[1]);
if (result_array == NULL) {
goto out;
}
Expand Down
4 changes: 2 additions & 2 deletions python/tests/test_relatedness_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,15 +311,15 @@ def verify_relatedness_vector(
)
Sigma = relatedness_matrix(ts, centre=centre)
R2 = Sigma.dot(w)
# R3 = ts.genetic_relatedness_vector(w, mode="branch", centre=centre)
R3 = ts.genetic_relatedness_vector(w, mode="branch", centre=centre)
if verbosity > 0:
print(ts.draw_text())
print("weights:", w)
print("here:", R1)
print("with ts:", R2)
print("Sigma:", Sigma)
np.testing.assert_allclose(R1, R2, atol=1e-14)
# np.testing.assert_allclose(R1, R3)
np.testing.assert_allclose(R1, R3)
return R1


Expand Down
8 changes: 7 additions & 1 deletion python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -8460,14 +8460,20 @@ def genetic_relatedness_vector(
raise ValueError(
"First trait dimension must be equal to number of samples."
)
return self.__weighted_vector_stat(
# TODO: this should happen in C
if centre:
W = np.array(W - np.mean(W, axis=0))
out = self.__weighted_vector_stat(
self._ll_tree_sequence.genetic_relatedness_vector,
W,
windows=windows,
mode=mode,
span_normalise=span_normalise,
centre=centre,
)
if centre:
out -= np.mean(out, axis=0)
return out

def trait_covariance(self, W, windows=None, mode="site", span_normalise=True):
"""
Expand Down

0 comments on commit bb5b603

Please sign in to comment.