Skip to content

Commit

Permalink
Allow windows for genetic_relatedness_vector to not span the whole
Browse files Browse the repository at this point in the history
genome
  • Loading branch information
petrelharp committed Sep 26, 2024
1 parent 36c6786 commit e6bbbb6
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 79 deletions.
8 changes: 7 additions & 1 deletion c/tests/test_stats.c
Original file line number Diff line number Diff line change
Expand Up @@ -2024,6 +2024,12 @@ test_empty_genetic_relatedness_vector(void)
&ts, num_weights, weights, 1, windows, result, TSK_STAT_NONCENTRED);
CU_ASSERT_EQUAL_FATAL(ret, 0);

windows[0] = 0.5 * tsk_treeseq_get_sequence_length(&ts);
windows[1] = 0.75 * tsk_treeseq_get_sequence_length(&ts);
ret = tsk_treeseq_genetic_relatedness_vector(
&ts, num_weights, weights, 1, windows, result, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);

tsk_treeseq_free(&ts);
free(weights);
free(result);
Expand Down Expand Up @@ -2135,7 +2141,7 @@ test_paper_ex_genetic_relatedness_vector_errors(void)
&ts, 1, weights, 2, windows, result, TSK_STAT_BRANCH);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS);

windows[0] = 10;
windows[0] = 12;
ret = tsk_treeseq_genetic_relatedness_vector(
&ts, 1, weights, 2, windows, result, TSK_STAT_BRANCH);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS);
Expand Down
118 changes: 78 additions & 40 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -5197,6 +5197,7 @@ tsk_tree_position_print_state(const tsk_tree_position_t *self, FILE *out)
{
fprintf(out, "Tree position state\n");
fprintf(out, "index = %d\n", (int) self->index);
fprintf(out, "interval = [%f,\t%f)\n", self->interval.left, self->interval.right);
fprintf(
out, "out = start=%d\tstop=%d\n", (int) self->out.start, (int) self->out.stop);
fprintf(
Expand Down Expand Up @@ -9910,7 +9911,8 @@ typedef struct {
tsk_flags_t options;
double *result;
/* tree */
double tree_left;
tsk_tree_position_t tree_pos;
double position;
tsk_size_t num_nodes;
tsk_id_t *parent;
double *x;
Expand All @@ -9926,7 +9928,9 @@ tsk_matvec_calculator_print_state(const tsk_matvec_calculator_t *self, FILE *out

fprintf(out, "Matvec state:\n");
fprintf(out, "options = %d\n", self->options);
fprintf(out, "tree_left = %f\n", self->tree_left);
fprintf(out, "position = %f\n", self->position);
fprintf(out, "tree_pos:\n");
tsk_tree_position_print_state(&self->tree_pos, out);
fprintf(out, "samples = %lld: [", (long long) num_samples);
fprintf(out, "]\n");
fprintf(out, "node\tparent\tx\tv\tw");
Expand All @@ -9952,16 +9956,19 @@ tsk_matvec_calculator_init(tsk_matvec_calculator_t *self, const tsk_treeseq_t *t
tsk_size_t k;
tsk_id_t u, j;
double *weight_means = tsk_malloc(num_weights * sizeof(*weight_means));
const tsk_size_t num_trees = ts->num_trees;
const double *restrict breakpoints = ts->breakpoints;
tsk_id_t index;

self->ts = ts;
self->tree_left = 0.0;
self->num_weights = num_weights;
self->weights = weights;
self->num_windows = num_windows;
self->windows = windows;
self->options = options;
self->result = result;
self->num_nodes = num_nodes;
self->position = windows[0];

self->parent = tsk_malloc(num_nodes * sizeof(*self->parent));
self->x = tsk_calloc(num_nodes, sizeof(*self->x));
Expand All @@ -9977,6 +9984,20 @@ tsk_matvec_calculator_init(tsk_matvec_calculator_t *self, const tsk_treeseq_t *t
tsk_memset(result, 0, num_windows * num_samples * num_weights * sizeof(*result));
tsk_memset(self->parent, TSK_NULL, num_nodes * sizeof(*self->parent));

ret = tsk_tree_position_init(&self->tree_pos, ts, 0);
if (ret != 0) {
goto out;

Check warning on line 9989 in c/tskit/trees.c

View check run for this annotation

Codecov / codecov/patch

c/tskit/trees.c#L9989

Added line #L9989 was not covered by tests
}
/* seek to the first window */
index = (tsk_id_t) tsk_search_sorted(breakpoints, num_trees + 1, windows[0]);
if (breakpoints[index] > windows[0]) {
index--;
}
ret = tsk_tree_position_seek_forward(&self->tree_pos, index);
if (ret != 0) {
goto out;

Check warning on line 9998 in c/tskit/trees.c

View check run for this annotation

Codecov / codecov/patch

c/tskit/trees.c#L9998

Added line #L9998 was not covered by tests
}

for (k = 0; k < num_weights; k++) {
weight_means[k] = 0.0;
}
Expand Down Expand Up @@ -10012,14 +10033,15 @@ tsk_matvec_calculator_free(tsk_matvec_calculator_t *self)
tsk_safe_free(self->x);
tsk_safe_free(self->w);
tsk_safe_free(self->v);
tsk_tree_position_free(&self->tree_pos);

/* Make this safe for multiple free calls */
memset(self, 0, sizeof(*self));
return 0;
}

static inline void
tsk_matvec_calculator_add_z(tsk_id_t u, tsk_id_t p, const double tree_left,
tsk_matvec_calculator_add_z(tsk_id_t u, tsk_id_t p, const double position,
double *restrict x, const tsk_size_t num_weights, double *restrict w,
double *restrict v, const double *restrict nodes_time)
{
Expand All @@ -10029,15 +10051,15 @@ tsk_matvec_calculator_add_z(tsk_id_t u, tsk_id_t p, const double tree_left,

if (p != TSK_NULL) {
t = nodes_time[p] - nodes_time[u];
span = tree_left - x[u];
span = position - x[u];
// do this: self->v[u] += t * span * self->w[u];
w_row = GET_2D_ROW(w, num_weights, u);
v_row = GET_2D_ROW(v, num_weights, u);
for (j = 0; j < num_weights; j++) {
v_row[j] += t * span * w_row[j];
}
}
x[u] = tree_left;
x[u] = position;
}

static void
Expand All @@ -10047,7 +10069,7 @@ tsk_matvec_calculator_adjust_path_up(
tsk_size_t j;
double *p_row, *c_row;
const tsk_id_t *restrict parent = self->parent;
const double tree_left = self->tree_left;
const double position = self->position;
double *restrict x = self->x;
const tsk_size_t num_weights = self->num_weights;
double *restrict w = self->w;
Expand All @@ -10057,7 +10079,7 @@ tsk_matvec_calculator_adjust_path_up(
// sign = -1 for removing edges, +1 for adding
while (p != TSK_NULL) {
tsk_matvec_calculator_add_z(
p, parent[p], tree_left, x, num_weights, w, v, nodes_time);
p, parent[p], position, x, num_weights, w, v, nodes_time);
// do this: self->v[c] -= sign * self->v[p];
p_row = GET_2D_ROW(v, num_weights, p);
c_row = GET_2D_ROW(v, num_weights, c);
Expand All @@ -10078,15 +10100,15 @@ static void
tsk_matvec_calculator_remove_edge(tsk_matvec_calculator_t *self, tsk_id_t p, tsk_id_t c)
{
tsk_id_t *parent = self->parent;
const double tree_left = self->tree_left;
const double position = self->position;
double *restrict x = self->x;
const tsk_size_t num_weights = self->num_weights;
double *restrict w = self->w;
double *restrict v = self->v;
const double *restrict nodes_time = self->ts->tables->nodes.time;

tsk_matvec_calculator_add_z(
c, parent[c], tree_left, x, num_weights, w, v, nodes_time);
c, parent[c], position, x, num_weights, w, v, nodes_time);
parent[c] = TSK_NULL;
tsk_matvec_calculator_adjust_path_up(self, p, c, -1);
}
Expand All @@ -10097,7 +10119,7 @@ tsk_matvec_calculator_insert_edge(tsk_matvec_calculator_t *self, tsk_id_t p, tsk
tsk_id_t *parent = self->parent;

tsk_matvec_calculator_adjust_path_up(self, p, c, +1);
self->x[c] = self->tree_left;
self->x[c] = self->position;
parent[c] = p;
}

Expand All @@ -10109,7 +10131,7 @@ tsk_matvec_calculator_write_output(tsk_matvec_calculator_t *self, double *restri
tsk_size_t j, k;
tsk_size_t n = tsk_treeseq_get_num_samples(self->ts);
const tsk_size_t num_weights = self->num_weights;
const double tree_left = self->tree_left;
const double position = self->position;
double *u_row, *out_row;
double *out_means = tsk_malloc(num_weights * sizeof(*out_means));
const tsk_id_t *restrict parent = self->parent;
Expand All @@ -10128,9 +10150,9 @@ tsk_matvec_calculator_write_output(tsk_matvec_calculator_t *self, double *restri
out_row = GET_2D_ROW(y, num_weights, j);
u = samples[j];
while (u != TSK_NULL) {
if (x[u] != tree_left) {
if (x[u] != position) {
tsk_matvec_calculator_add_z(
u, parent[u], tree_left, x, num_weights, w, v, nodes_time);
u, parent[u], position, x, num_weights, w, v, nodes_time);
}
u_row = GET_2D_ROW(v, num_weights, u);
for (k = 0; k < num_weights; k++) {
Expand Down Expand Up @@ -10174,48 +10196,64 @@ tsk_matvec_calculator_run(tsk_matvec_calculator_t *self)
tsk_size_t j, k, m;
tsk_id_t e, p, c;
tsk_size_t n = tsk_treeseq_get_num_samples(self->ts);
double tree_right;
const double sequence_length = self->ts->tables->sequence_length;
const tsk_size_t num_edges = self->ts->tables->edges.num_rows;
const tsk_id_t *restrict I = self->ts->tables->indexes.edge_insertion_order;
const tsk_id_t *restrict O = self->ts->tables->indexes.edge_removal_order;
const double *restrict edge_right = self->ts->tables->edges.right;
const double *restrict edge_left = self->ts->tables->edges.left;
const tsk_id_t *restrict edge_child = self->ts->tables->edges.child;
const tsk_id_t *restrict edge_parent = self->ts->tables->edges.parent;
const double *restrict windows = self->windows;
double *restrict out;
tsk_tree_position_t tree_pos = self->tree_pos;
const tsk_id_t *restrict in_order = tree_pos.in.order;
const tsk_id_t *restrict out_order = tree_pos.out.order;
bool valid;
double next_position;

j = 0;
k = 0;
m = 0;
tree_right = sequence_length;
self->position = windows[0];

while (
m < self->num_windows && k < num_edges && self->tree_left <= sequence_length) {
while (k < num_edges && edge_right[O[k]] == self->tree_left) {
e = O[k];
p = edge_parent[e];
c = edge_child[e];
tsk_matvec_calculator_remove_edge(self, p, c);
k++;
}
while (j < num_edges && edge_left[I[j]] == self->tree_left) {
e = I[j];
for (j = (tsk_size_t) tree_pos.in.start; j != (tsk_size_t) tree_pos.in.stop; j++) {
e = in_order[j];
// note that edge_left[e] <= self->position is always true
// since we are starting from the empty sequence
if (self->position < edge_right[e]) {
p = edge_parent[e];
c = edge_child[e];
tsk_matvec_calculator_insert_edge(self, p, c);
self->x[c] = self->tree_left;
j++;
}
tree_right = self->windows[m + 1];
}

valid = tsk_tree_position_next(&tree_pos);
j = (tsk_size_t) tree_pos.in.start;
k = (tsk_size_t) tree_pos.out.start;
while (m < self->num_windows) {
if (valid && self->position == tree_pos.interval.left) {
for (k = (tsk_size_t) tree_pos.out.start;
k != (tsk_size_t) tree_pos.out.stop; k++) {
e = out_order[k];
p = edge_parent[e];
c = edge_child[e];
tsk_matvec_calculator_remove_edge(self, p, c);
}
for (j = (tsk_size_t) tree_pos.in.start; j != (tsk_size_t) tree_pos.in.stop;
j++) {
e = in_order[j];
p = edge_parent[e];
c = edge_child[e];
tsk_matvec_calculator_insert_edge(self, p, c);
}
valid = tsk_tree_position_next(&tree_pos);
}
next_position = windows[m + 1];
if (j < num_edges) {
tree_right = TSK_MIN(tree_right, edge_left[I[j]]);
next_position = TSK_MIN(next_position, edge_left[in_order[j]]);
}
if (k < num_edges) {
tree_right = TSK_MIN(tree_right, edge_right[O[k]]);
next_position = TSK_MIN(next_position, edge_right[out_order[k]]);
}
self->tree_left = tree_right;
if (self->tree_left == self->windows[m + 1]) {
tsk_bug_assert(self->position < next_position);
self->position = next_position;
if (self->position == windows[m + 1]) {
out = GET_2D_ROW(self->result, self->num_weights * n, m);
tsk_matvec_calculator_write_output(self, out);
m += 1;
Expand Down Expand Up @@ -10245,7 +10283,7 @@ tsk_treeseq_genetic_relatedness_vector(const tsk_treeseq_t *self, tsk_size_t num
ret = TSK_ERR_UNSUPPORTED_STAT_MODE;
goto out;
}
ret = tsk_treeseq_check_windows(self, num_windows, windows, TSK_REQUIRE_FULL_SPAN);
ret = tsk_treeseq_check_windows(self, num_windows, windows, 0);
if (ret != 0) {
goto out;
}
Expand Down
2 changes: 0 additions & 2 deletions python/tests/test_lowlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2765,12 +2765,10 @@ def test_window_errors(self):
L = ts.get_sequence_length()
bad_windows = [
[L, 0],
[0.1, L],
[-1, L],
[0, L + 0.1],
[0, 0.1, 0.1, L],
[0, -1, L],
[0, 0.1, 0.05, 0.2, L],
]
for bad_window in bad_windows:
with pytest.raises(_tskit.LibraryError):
Expand Down
Loading

0 comments on commit e6bbbb6

Please sign in to comment.