Skip to content

Commit

Permalink
Site divmat based on genotype decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Aug 14, 2023
1 parent b6f9872 commit 86e2960
Show file tree
Hide file tree
Showing 3 changed files with 246 additions and 135 deletions.
15 changes: 10 additions & 5 deletions c/tests/test_stats.c
Original file line number Diff line number Diff line change
Expand Up @@ -1132,7 +1132,6 @@ test_single_tree_divergence_matrix_multi_root(void)
int ret;
double result[16];
double D_branch[16] = { 0, 2, 3, 3, 2, 0, 3, 3, 3, 3, 0, 4, 3, 3, 4, 0 };
double D_site[16] = { 0, 4, 6, 6, 4, 0, 6, 6, 6, 6, 0, 8, 6, 6, 8, 0 };

const char *nodes = "1 0 -1 -1\n"
"1 0 -1 -1\n" /* 2.00┊ 5 ┊ */
Expand All @@ -1142,7 +1141,7 @@ test_single_tree_divergence_matrix_multi_root(void)
"0 2 -1 -1\n"; /* 0 * * * * 1 */
const char *edges = "0 1 4 0,1\n"
"0 1 5 2,3\n";
/* Two mutations per branch unit so we get twice branch length value */
/* Two mutations per branch */
const char *sites = "0.1 A\n"
"0.2 A\n"
"0.3 A\n"
Expand All @@ -1166,9 +1165,8 @@ test_single_tree_divergence_matrix_multi_root(void)
CU_ASSERT_EQUAL_FATAL(ret, 0);
assert_arrays_almost_equal(16, result, D_branch);

ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_SITE, result);
CU_ASSERT_EQUAL_FATAL(ret, 0);
assert_arrays_almost_equal(16, result, D_site);
verify_divergence_matrix(&ts, TSK_STAT_SITE);
verify_divergence_matrix(&ts, TSK_STAT_BRANCH);

tsk_treeseq_free(&ts);
}
Expand Down Expand Up @@ -2041,6 +2039,13 @@ test_simplest_divergence_matrix(void)
ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS);

sample_ids[0] = 1;
ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_DUPLICATE_SAMPLE);
ret = tsk_treeseq_divergence_matrix(
&ts, 2, sample_ids, 0, NULL, TSK_STAT_BRANCH, result);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_DUPLICATE_SAMPLE);

tsk_treeseq_free(&ts);
}

Expand Down
209 changes: 132 additions & 77 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -6597,43 +6597,62 @@ tsk_treeseq_divergence_matrix_branch(const tsk_treeseq_t *self, tsk_size_t num_s
return ret;
}

static tsk_size_t
count_mutations_on_path(tsk_id_t u, tsk_id_t v, const tsk_id_t *restrict parent,
const double *restrict time, const tsk_size_t *restrict mutations_per_node)
// FIXME see #2817
// Just including this here for now as it's the simplest option. Everything
// will probably move to stats.[c,h] in the near future though, and it
// can pull in ``genotypes.h`` without issues.
#include <tskit/genotypes.h>

static void
update_site_divergence(const tsk_variant_t *var, const tsk_id_t *restrict A,
const tsk_size_t *restrict offsets, double *D)

{
double tu, tv;
tsk_size_t count = 0;
const tsk_size_t num_alleles = var->num_alleles;
const tsk_id_t n = (tsk_id_t) var->num_samples;

tu = time[u];
tv = time[v];
while (u != v) {
if (tu < tv) {
count += mutations_per_node[u];
u = parent[u];
if (u == TSK_NULL) {
break;
}
tu = time[u];
} else {
count += mutations_per_node[v];
v = parent[v];
if (v == TSK_NULL) {
break;
tsk_size_t a, b, j, k;
tsk_id_t u, v;

for (a = 0; a < num_alleles; a++) {
for (b = a + 1; b < num_alleles; b++) {
for (j = offsets[a]; j < offsets[a + 1]; j++) {
for (k = offsets[b]; k < offsets[b + 1]; k++) {
u = A[j];
v = A[k];
/* Only increment the upper triangle to (hopefully) improve memory
* access patterns */
if (u > v) {
v = A[j];
u = A[k];
}
D[u * n + v]++;
}
}
tv = time[v];
}
}
if (u != v) {
while (u != TSK_NULL) {
count += mutations_per_node[u];
u = parent[u];
}
while (v != TSK_NULL) {
count += mutations_per_node[v];
v = parent[v];
}

static void
group_alleles(const tsk_variant_t *var, tsk_id_t *restrict A, tsk_size_t *offsets)
{
const tsk_size_t n = var->num_samples;
const int32_t *restrict genotypes = var->genotypes;
tsk_id_t a;
tsk_size_t j, k;

k = 0;
offsets[0] = 0;
for (a = 0; a < (tsk_id_t) var->num_alleles; a++) {
offsets[a + 1] = offsets[a];
for (j = 0; j < n; j++) {
if (genotypes[j] == a) {
offsets[a + 1]++;
A[k] = (tsk_id_t) j;
k++;
}
}
}
return count;
}

static int
Expand All @@ -6643,72 +6662,100 @@ tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_sam
double *restrict result)
{
int ret = 0;
tsk_tree_t tree;
const tsk_size_t n = num_samples;
const tsk_size_t num_nodes = self->tables->nodes.num_rows;
const double *restrict nodes_time = self->tables->nodes.time;
tsk_size_t i, j, k, tree_site, tree_mut;
tsk_site_t site;
tsk_mutation_t mut;
tsk_id_t u, v;
double left, right, span_left, span_right;
tsk_size_t i;
tsk_id_t site_id;
double left, right;
double *restrict D;
tsk_size_t *mutations_per_node = tsk_malloc(num_nodes * sizeof(*mutations_per_node));

ret = tsk_tree_init(&tree, self, 0);
const tsk_id_t num_sites = (tsk_id_t) self->tables->sites.num_rows;
const double *restrict sites_position = self->tables->sites.position;
tsk_id_t *A = tsk_malloc(num_samples * sizeof(*A));
/* Allocate the allele offsets at the first variant */
tsk_size_t max_alleles = 0;
tsk_size_t *allele_offsets = NULL;
tsk_variant_t variant;

ret = tsk_variant_init(
&variant, self, samples, num_samples, NULL, TSK_ISOLATED_NOT_MISSING);
if (ret != 0) {
goto out;
}
if (mutations_per_node == NULL) {

if (A == NULL) {
ret = TSK_ERR_NO_MEMORY;
goto out;
}

site_id = 0;
while (site_id < num_sites && sites_position[site_id] < windows[0]) {
site_id++;
}

for (i = 0; i < num_windows; i++) {
left = windows[i];
right = windows[i + 1];
D = result + i * n * n;
ret = tsk_tree_seek(&tree, left, 0);
if (ret != 0) {
goto out;
}
while (tree.interval.left < right && tree.index != -1) {
span_left = TSK_MAX(tree.interval.left, left);
span_right = TSK_MIN(tree.interval.right, right);
D = result + i * num_samples * num_samples;

/* NOTE: we could avoid this full memset across all nodes by doing
* the same loops again and decrementing at the end of the main
* tree-loop. It's probably not worth it though, because of the
* overwhelming O(n^2) below */
tsk_memset(mutations_per_node, 0, num_nodes * sizeof(*mutations_per_node));
for (tree_site = 0; tree_site < tree.sites_length; tree_site++) {
site = tree.sites[tree_site];
if (span_left <= site.position && site.position < span_right) {
for (tree_mut = 0; tree_mut < site.mutations_length; tree_mut++) {
mut = site.mutations[tree_mut];
mutations_per_node[mut.node]++;
}
}
if (site_id < num_sites) {
tsk_bug_assert(sites_position[site_id] >= left);
}
while (site_id < num_sites && sites_position[site_id] < right) {
ret = tsk_variant_decode(&variant, site_id, 0);
if (ret != 0) {
goto out;
}

for (j = 0; j < n; j++) {
u = samples[j];
for (k = j + 1; k < n; k++) {
v = samples[k];
D[j * n + k] += (double) count_mutations_on_path(
u, v, tree.parent, nodes_time, mutations_per_node);
if (variant.num_alleles > max_alleles) {
/* could do some kind of doubling here, but there's no
* point - just keep it simple for testing. */
max_alleles = variant.num_alleles;
tsk_safe_free(allele_offsets);
allele_offsets = tsk_malloc((max_alleles + 1) * sizeof(*allele_offsets));
if (allele_offsets == NULL) {
ret = TSK_ERR_NO_MEMORY;
goto out;
}
}
ret = tsk_tree_next(&tree);
if (ret < 0) {
goto out;
}
group_alleles(&variant, A, allele_offsets);
update_site_divergence(&variant, A, allele_offsets, D);
site_id++;
}
}
ret = 0;
out:
tsk_tree_free(&tree);
tsk_safe_free(mutations_per_node);
tsk_variant_free(&variant);
tsk_safe_free(A);
tsk_safe_free(allele_offsets);
return ret;
}

static int
get_sample_index_map(const tsk_size_t num_nodes, const tsk_size_t num_samples,
const tsk_id_t *restrict samples, tsk_id_t **ret_sample_index_map)
{
int ret = 0;
tsk_size_t j;
tsk_id_t u;
tsk_id_t *sample_index_map = tsk_malloc(num_nodes * sizeof(*sample_index_map));

if (sample_index_map == NULL) {
ret = TSK_ERR_NO_MEMORY;
goto out;
}
/* Assign the output pointer here so that it will be freed in the case
* of an error raised in the input checking */
*ret_sample_index_map = sample_index_map;

for (j = 0; j < num_nodes; j++) {
sample_index_map[j] = TSK_NULL;
}
for (j = 0; j < num_samples; j++) {
u = samples[j];
if (sample_index_map[u] != TSK_NULL) {
ret = TSK_ERR_DUPLICATE_SAMPLE;
goto out;
}
sample_index_map[u] = (tsk_id_t) j;
}
out:
return ret;
}

Expand Down Expand Up @@ -6739,9 +6786,11 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples,
const tsk_id_t *samples = self->samples;
tsk_size_t n = self->num_samples;
const double default_windows[] = { 0, self->tables->sequence_length };
const tsk_size_t num_nodes = self->tables->nodes.num_rows;
bool stat_site = !!(options & TSK_STAT_SITE);
bool stat_branch = !!(options & TSK_STAT_BRANCH);
bool stat_node = !!(options & TSK_STAT_NODE);
tsk_id_t *sample_index_map = NULL;

if (stat_node) {
ret = TSK_ERR_UNSUPPORTED_STAT_MODE;
Expand Down Expand Up @@ -6785,6 +6834,11 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples,
}
}

ret = get_sample_index_map(num_nodes, n, samples, &sample_index_map);
if (ret != 0) {
goto out;
}

tsk_memset(result, 0, num_windows * n * n * sizeof(*result));

if (stat_branch) {
Expand All @@ -6801,5 +6855,6 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples,
fill_lower_triangle(result, n, num_windows);

out:
tsk_safe_free(sample_index_map);
return ret;
}
Loading

0 comments on commit 86e2960

Please sign in to comment.