Skip to content

Commit

Permalink
relatedness matrix-vector operation (rebase)
Browse files Browse the repository at this point in the history
  • Loading branch information
petrelharp committed Sep 19, 2024
1 parent dcee409 commit 39c4c14
Show file tree
Hide file tree
Showing 7 changed files with 915 additions and 0 deletions.
29 changes: 29 additions & 0 deletions c/tests/test_stats.c
Original file line number Diff line number Diff line change
Expand Up @@ -1992,6 +1992,33 @@ test_paper_ex_genetic_relatedness_weighted_errors(void)
tsk_treeseq_free(&ts);
}

/* Smoke test for tsk_treeseq_genetic_relatedness_vector on the paper example.
 *
 * Uses a single weight column (num_weights = 1) with every sample weighted
 * 1.0, no windows, and no options. Only the return code is checked; the
 * numerical contents of result are not verified here. */
static void
test_paper_ex_genetic_relatedness_vector(void)
{
    int ret;
    tsk_treeseq_t ts;
    tsk_size_t num_samples;
    double *weights, *result;
    tsk_size_t j;

    tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites,
        paper_ex_mutations, paper_ex_individuals, NULL, 0);
    num_samples = tsk_treeseq_get_num_samples(&ts);

    /* Both buffers hold one value per sample because num_weights is 1.
     * Fail the test immediately rather than dereferencing a NULL buffer. */
    weights = tsk_malloc(num_samples * sizeof(double));
    CU_ASSERT_FATAL(weights != NULL);
    result = tsk_malloc(num_samples * sizeof(double));
    CU_ASSERT_FATAL(result != NULL);
    for (j = 0; j < num_samples; j++) {
        weights[j] = 1.0;
    }

    ret = tsk_treeseq_genetic_relatedness_vector(&ts, 1, weights, 0, NULL, result, 0);
    CU_ASSERT_EQUAL_FATAL(ret, 0);

    tsk_treeseq_free(&ts);
    free(weights);
    free(result);
}

static void
test_paper_ex_Y2_errors(void)
{
Expand Down Expand Up @@ -3532,6 +3559,8 @@ main(int argc, char **argv)
test_paper_ex_genetic_relatedness_weighted },
{ "test_paper_ex_genetic_relatedness_weighted_errors",
test_paper_ex_genetic_relatedness_weighted_errors },
{ "test_paper_ex_genetic_relatedness_vector",
test_paper_ex_genetic_relatedness_vector },
{ "test_paper_ex_Y2_errors", test_paper_ex_Y2_errors },
{ "test_paper_ex_Y2", test_paper_ex_Y2 },
{ "test_paper_ex_f2_errors", test_paper_ex_f2_errors },
Expand Down
285 changes: 285 additions & 0 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -9896,3 +9896,288 @@ tsk_treeseq_pair_coalescence_rates(const tsk_treeseq_t *self, tsk_size_t num_sam
out:
return ret;
}

/* ======================================================== *
* Relatedness matrix-vector product
* ======================================================== */

/* Working state for the relatedness matrix-vector product computed by
 * tsk_treeseq_genetic_relatedness_vector. The first group of fields is stored
 * as passed to tsk_matvec_calculator_init; the second group is allocated by
 * init and released by tsk_matvec_calculator_free. */
typedef struct {
/* Tree sequence being traversed. */
const tsk_treeseq_t *ts;
/* Row length used when weights, w, v and result are indexed as 2D arrays. */
tsk_size_t num_weights;
/* Caller-owned input weights; row j is copied into w for sample ts->samples[j]. */
const double *weights;
/* Option flags; TSK_DEBUG makes the calculator print its state. */
tsk_flags_t options;
/* Caller-owned output buffer; zeroed by init and filled at the end of run. */
double *result;
/* Per-tree state, updated as edges are removed and inserted. */
/* Coordinate the traversal has reached; starts at 0.0 and is advanced by run. */
double tree_left;
/* Set to the number of rows in the node table, i.e. the first id past the
 * real nodes. One extra slot per sample follows it (virtual_root + 1 + j). */
tsk_id_t virtual_root;
/* Length of parent and x, and row count of w and v:
 * node table rows + number of samples + 1. */
tsk_size_t num_nodes;
/* parent[u] is the current parent of slot u, or TSK_NULL if it has none. */
tsk_id_t *parent;
/* x[u] is the value of tree_left when slot u was last updated. */
double *x;
/* num_nodes rows of num_weights values. Sample rows start as the input
 * weights; rows of ancestors are adjusted when edges are added or removed. */
double *w;
/* num_nodes rows of num_weights accumulated values; the rows of the slots
 * after virtual_root are what gets copied into result. */
double *v;
} tsk_matvec_calculator_t;

static void
tsk_matvec_calculator_print_state(const tsk_matvec_calculator_t *self, FILE *out)
{
tsk_id_t j, u;
tsk_size_t num_samples = tsk_treeseq_get_num_samples(self->ts);

fprintf(out, "Matvec state:\n");
fprintf(out, "options = %d\n", self->options);
fprintf(out, "tree_left = %f\n", self->tree_left);
fprintf(out, "samples = %lld: [", (long long) num_samples);
fprintf(out, "]\n");
fprintf(out, "node\tparent\tx\tv\tw");
fprintf(out, "\n");

for (j = 0; j < (tsk_id_t) self->num_nodes; j++) {
if (j < self->virtual_root) {
fprintf(out, "%lld\t", (long long) j);
} else if (j == self->virtual_root) {
fprintf(out, "VR:%lld\t", (long long) j);
} else {
u = self->ts->samples[j - self->virtual_root - 1];
fprintf(out, "%lld(%lld)\t", (long long) j, (long long) u);
}
fprintf(out, "%lld\t%g\t%g\t%g\n", (long long) self->parent[j], self->x[j],
self->v[j], self->w[j]);
}
}

/* Initialises a calculator for the tree sequence ts.
 *
 * Stores the inputs, allocates the per-node arrays (one slot per node table
 * row, one for the virtual root, and one extra "virtual sample" slot per
 * sample), zeroes result, copies each sample's weight row into w, and attaches
 * each virtual sample slot to its sample node.
 *
 * weights is read as num_samples rows of num_weights values; result is zeroed
 * as num_samples rows of num_weights values, so it must be at least that large.
 *
 * Returns 0 on success or TSK_ERR_NO_MEMORY if an allocation fails. On failure
 * the caller is still expected to call tsk_matvec_calculator_free, which
 * releases whatever was allocated. */
static int
tsk_matvec_calculator_init(tsk_matvec_calculator_t *self, const tsk_treeseq_t *ts,
    tsk_size_t num_weights, const double *weights, tsk_flags_t options, double *result)
{
    int ret = 0;
    tsk_size_t num_samples = tsk_treeseq_get_num_samples(ts);
    const tsk_size_t num_nodes = ts->tables->nodes.num_rows + num_samples + 1;
    const double *row;
    double *new_row;
    tsk_size_t k;
    tsk_id_t u, v, j;

    self->ts = ts;
    self->tree_left = 0.0;
    self->num_weights = num_weights;
    self->weights = weights;
    self->options = options;
    self->result = result;
    self->num_nodes = num_nodes;
    self->virtual_root = (tsk_id_t) ts->tables->nodes.num_rows;

    self->parent = tsk_malloc(num_nodes * sizeof(*self->parent));
    self->x = tsk_calloc(num_nodes, sizeof(*self->x));
    self->v = tsk_calloc(num_nodes, num_weights * sizeof(*self->v));
    self->w = tsk_calloc(num_nodes, num_weights * sizeof(*self->w));

    if (self->parent == NULL || self->x == NULL || self->w == NULL || self->v == NULL) {
        ret = TSK_ERR_NO_MEMORY;
        goto out;
    }

    /* result has one row of num_weights values per sample. */
    tsk_memset(result, 0, num_samples * num_weights * sizeof(*result));
    /* Byte-wise fill: presumably TSK_NULL is -1, so every element becomes
     * TSK_NULL -- TODO confirm against the definition of TSK_NULL. */
    tsk_memset(self->parent, TSK_NULL, num_nodes * sizeof(*self->parent));

    for (j = 0; j < (tsk_id_t) num_samples; j++) {
        u = ts->samples[j];
        row = GET_2D_ROW(weights, num_weights, j);
        new_row = GET_2D_ROW(self->w, num_weights, u);
        /* Copy the j-th weight row onto sample node u; the index advanced
         * here must be k, the column index. */
        for (k = 0; k < num_weights; k++) {
            new_row[k] = row[k];
        }
        /* add branch to the virtual sample */
        v = self->virtual_root + 1 + j;
        self->parent[v] = u;
    }
out:
    return ret;
}

/* Releases the arrays owned by the calculator (parent, x, w, v) and clears
 * the whole struct so that calling this again is harmless. The caller-owned
 * ts, weights and result pointers are not freed. Always returns 0. */
static int
tsk_matvec_calculator_free(tsk_matvec_calculator_t *self)
{
    tsk_safe_free(self->v);
    tsk_safe_free(self->w);
    tsk_safe_free(self->x);
    tsk_safe_free(self->parent);

    /* Zero every field so a repeated free sees only NULL pointers. */
    memset(self, 0, sizeof *self);
    return 0;
}

/* Brings slot u up to date with the current position (tree_left).
 *
 * If u has a parent and is a real node (below virtual_root), adds
 * (time[parent] - time[u]) * (tree_left - x[u]) * w[u] to v[u], row-wise.
 * In every case x[u] is then set to tree_left. */
static void
tsk_matvec_calculator_add_z(const tsk_matvec_calculator_t *self, tsk_id_t u)
{
    const double *restrict node_time = self->ts->tables->nodes.time;
    const tsk_id_t up = self->parent[u];
    double branch_length, elapsed;
    double *accum, *weight;
    tsk_size_t col;

    if (up != TSK_NULL && u < self->virtual_root) {
        branch_length = node_time[up] - node_time[u];
        elapsed = self->tree_left - self->x[u];
        weight = GET_2D_ROW(self->w, self->num_weights, u);
        accum = GET_2D_ROW(self->v, self->num_weights, u);
        /* v[u] += branch_length * elapsed * w[u], one column at a time. */
        for (col = 0; col < self->num_weights; col++) {
            accum[col] += branch_length * elapsed * weight[col];
        }
    }
    self->x[u] = self->tree_left;
}

/* Walks from p up through its ancestors, adjusting state for the edge (p, c)
 * being added (sign = +1) or removed (sign = -1).
 *
 * Each ancestor a on the path is first brought up to date with add_z; then
 * v[c] -= sign * v[a] and w[a] += sign * w[c], row-wise. */
static void
tsk_matvec_calculator_adjust_path_up(
    tsk_matvec_calculator_t *self, tsk_id_t p, tsk_id_t c, double sign)
{
    tsk_size_t col;
    double *anc_v, *anc_w;
    /* c never changes during the walk, so its rows can be located once. */
    double *child_v = GET_2D_ROW(self->v, self->num_weights, c);
    double *child_w = GET_2D_ROW(self->w, self->num_weights, c);

    for (; p != TSK_NULL; p = self->parent[p]) {
        tsk_matvec_calculator_add_z(self, p);
        anc_v = GET_2D_ROW(self->v, self->num_weights, p);
        anc_w = GET_2D_ROW(self->w, self->num_weights, p);
        /* v and w are separate arrays, so both updates can share one pass. */
        for (col = 0; col < self->num_weights; col++) {
            child_v[col] -= sign * anc_v[col];
            anc_w[col] += sign * child_w[col];
        }
    }
}

/* Removes the edge from parent p to child c: brings c up to date while it is
 * still attached, detaches it, then undoes its contribution along the path
 * from p to the root. */
static void
tsk_matvec_calculator_remove_edge(tsk_matvec_calculator_t *self, tsk_id_t p, tsk_id_t c)
{
    tsk_matvec_calculator_add_z(self, c);
    self->parent[c] = TSK_NULL;
    tsk_matvec_calculator_adjust_path_up(self, p, c, -1.0);
}

/* Inserts the edge from parent p to child c: applies c's contribution along
 * the path from p to the root, then attaches c and records the current
 * position as the time c was last updated. */
static void
tsk_matvec_calculator_insert_edge(tsk_matvec_calculator_t *self, tsk_id_t p, tsk_id_t c)
{
    tsk_matvec_calculator_adjust_path_up(self, p, c, +1.0);
    self->parent[c] = p;
    self->x[c] = self->tree_left;
}

/* Finalises the computation and copies the answer into result.
 *
 * First detaches every virtual sample slot from its sample node, which
 * propagates the accumulated values down into the virtual slots' rows of v.
 * Then copies the v row of the j-th virtual slot into row j of result, so
 * result holds one row of num_weights values per sample, in ts->samples order. */
static void
tsk_matvec_calculator_write_output(tsk_matvec_calculator_t *self)
{
    tsk_id_t u, v;
    tsk_size_t j, k;
    tsk_size_t n = tsk_treeseq_get_num_samples(self->ts);
    double *restrict y = self->result;
    double *v_row, *out_row;

    for (j = 0; j < n; j++) {
        u = self->ts->samples[j];
        v = self->virtual_root + 1 + (tsk_id_t) j;
        tsk_bug_assert(u == self->parent[v]);
        tsk_matvec_calculator_remove_edge(self, u, v);
    }
    for (j = 0; j < n; j++) {
        v = self->virtual_root + 1 + (tsk_id_t) j;
        v_row = GET_2D_ROW(self->v, self->num_weights, v);
        /* result is indexed by sample position j; the virtual slot id v is
         * only valid as an index into the calculator's own arrays. */
        out_row = GET_2D_ROW(y, self->num_weights, j);
        for (k = 0; k < self->num_weights; k++) {
            out_row[k] = v_row[k];
        }
    }
}

/* Sweeps left to right over the trees of the sequence and then writes the
 * output.
 *
 * At each breakpoint, edges ending at tree_left are removed (in removal
 * order), then edges starting at tree_left are inserted (in insertion order);
 * tree_left is then advanced to the next breakpoint, or to sequence_length
 * when no edges remain. With TSK_DEBUG set, the state is printed to the
 * library's debug stream after each step. Always returns 0. */
static int
tsk_matvec_calculator_run(tsk_matvec_calculator_t *self)
{
    int ret = 0;
    tsk_size_t j, k;
    tsk_id_t e, p, c;
    double tree_right;
    const double sequence_length = self->ts->tables->sequence_length;
    const tsk_size_t num_edges = self->ts->tables->edges.num_rows;
    const tsk_id_t *restrict I = self->ts->tables->indexes.edge_insertion_order;
    const tsk_id_t *restrict O = self->ts->tables->indexes.edge_removal_order;
    const double *restrict edge_right = self->ts->tables->edges.right;
    const double *restrict edge_left = self->ts->tables->edges.left;
    const tsk_id_t *restrict edge_child = self->ts->tables->edges.child;
    const tsk_id_t *restrict edge_parent = self->ts->tables->edges.parent;

    j = 0;
    k = 0;

    /* Keep going until every edge has been removed again, so that the final
     * tree's contribution up to sequence_length is accounted for. */
    while (k < num_edges || self->tree_left < sequence_length) {
        while (k < num_edges && edge_right[O[k]] == self->tree_left) {
            e = O[k];
            p = edge_parent[e];
            c = edge_child[e];
            tsk_matvec_calculator_remove_edge(self, p, c);
            k++;
        }
        while (j < num_edges && edge_left[I[j]] == self->tree_left) {
            e = I[j];
            p = edge_parent[e];
            c = edge_child[e];
            /* insert_edge also records tree_left in x[c]. */
            tsk_matvec_calculator_insert_edge(self, p, c);
            j++;
        }
        /* The next breakpoint is the nearest pending edge start or end. */
        tree_right = sequence_length;
        if (j < num_edges) {
            tree_right = TSK_MIN(tree_right, edge_left[I[j]]);
        }
        if (k < num_edges) {
            tree_right = TSK_MIN(tree_right, edge_right[O[k]]);
        }
        self->tree_left = tree_right;
        if (self->options & TSK_DEBUG) {
            /* Use the configurable debug stream, as the entry point does. */
            tsk_matvec_calculator_print_state(self, tsk_get_debug_stream());
        }
    }
    tsk_matvec_calculator_write_output(self);

    return ret;
}

/* Computes the relatedness matrix-vector product for the given weights and
 * writes it to result.
 *
 * Windows are not supported yet: num_windows must be 0 and windows must be
 * NULL (both are enforced with tsk_bug_assert). With TSK_DEBUG in options the
 * calculator state is printed to the debug stream.
 *
 * Returns 0 on success, or the error code from initialising or running the
 * calculator. */
int
tsk_treeseq_genetic_relatedness_vector(const tsk_treeseq_t *self, tsk_size_t num_weights,
    const double *weights, tsk_size_t num_windows, const double *windows, double *result,
    tsk_flags_t options)
{
    int ret;
    tsk_matvec_calculator_t calc;

    // TODO add windows
    tsk_bug_assert(num_windows == 0);
    tsk_bug_assert(windows == NULL);

    /* Zero the struct first so the free below is safe even if init fails. */
    memset(&calc, 0, sizeof(calc));

    ret = tsk_matvec_calculator_init(&calc, self, num_weights, weights, options, result);
    if (ret == 0) {
        if (options & TSK_DEBUG) {
            tsk_matvec_calculator_print_state(&calc, tsk_get_debug_stream());
        }
        ret = tsk_matvec_calculator_run(&calc);
    }
    tsk_matvec_calculator_free(&calc);
    return ret;
}
10 changes: 10 additions & 0 deletions c/tskit/trees.h
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,16 @@ int tsk_treeseq_genetic_relatedness_weighted(const tsk_treeseq_t *self,
const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows,
double *result, tsk_flags_t options);

/* One way weighted stats with vector output */

/* Function type for one-way weighted statistics that produce vector output.
 * tsk_treeseq_genetic_relatedness_vector, declared below, has this signature:
 * a tree sequence, a weights array with num_weights values per row, an
 * optional windows specification, an output buffer and option flags, with an
 * int status as the return value. */
typedef int weighted_vector_method(const tsk_treeseq_t *self, tsk_size_t num_weights,
const double *weights, tsk_size_t num_windows, const double *windows, double *result,
tsk_flags_t options);

/* Relatedness matrix-vector product.
 *
 * self:        the tree sequence.
 * num_weights: number of values per row of weights (and of result).
 * weights:     input weights, read as one row of num_weights values per
 *              sample, in the order of the tree sequence's samples.
 * num_windows: must currently be 0; windows are not yet supported.
 * windows:     must currently be NULL.
 * result:      caller-allocated output buffer; intended to hold one row of
 *              num_weights values per sample.
 * options:     TSK_DEBUG prints the internal calculator state.
 *
 * Returns 0 on success or a negative-style tskit error code such as
 * TSK_ERR_NO_MEMORY on failure (NOTE(review): sign convention of error codes
 * is assumed from the rest of the library -- confirm). */
int tsk_treeseq_genetic_relatedness_vector(const tsk_treeseq_t *self,
tsk_size_t num_weights, const double *weights, tsk_size_t num_windows,
const double *windows, double *result, tsk_flags_t options);

/* One way sample set stats */

typedef int one_way_sample_stat_method(const tsk_treeseq_t *self,
Expand Down
1 change: 1 addition & 0 deletions docs/python-api.md
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ Single site
TreeSequence.genealogical_nearest_neighbours
TreeSequence.genetic_relatedness
TreeSequence.genetic_relatedness_weighted
TreeSequence.genetic_relatedness_vector
TreeSequence.general_stat
TreeSequence.segregating_sites
TreeSequence.sample_count_stat
Expand Down
Loading

0 comments on commit 39c4c14

Please sign in to comment.