Skip to content

Commit

Permalink
Convert LS HMM code to use tree_position_t
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Jul 12, 2023
1 parent a8bc588 commit 158ed1a
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 34 deletions.
63 changes: 31 additions & 32 deletions c/tskit/haplotype_matching.c
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ int
tsk_ls_hmm_free(tsk_ls_hmm_t *self)
{
tsk_tree_free(&self->tree);
tsk_diff_iter_free(&self->diffs);
tsk_tree_position_free(&self->tree_pos);
tsk_safe_free(self->recombination_rate);
tsk_safe_free(self->mutation_rate);
tsk_safe_free(self->recombination_rate);
Expand Down Expand Up @@ -248,9 +248,8 @@ tsk_ls_hmm_reset(tsk_ls_hmm_t *self)
tsk_memset(self->transition_parent, 0xff,
self->max_transitions * sizeof(*self->transition_parent));

/* This is safe because we've already zero'd out the memory. */
tsk_diff_iter_free(&self->diffs);
ret = tsk_diff_iter_init_from_ts(&self->diffs, self->tree_sequence, false);
tsk_tree_position_free(&self->tree_pos);
ret = tsk_tree_position_init(&self->tree_pos, self->tree_sequence, 0);
if (ret != 0) {
goto out;
}
Expand Down Expand Up @@ -306,47 +305,48 @@ tsk_ls_hmm_update_tree(tsk_ls_hmm_t *self)
int ret = 0;
tsk_id_t *restrict parent = self->parent;
tsk_id_t *restrict T_index = self->transition_index;
const tsk_id_t *restrict edges_child = self->tree_sequence->tables->edges.child;
const tsk_id_t *restrict edges_parent = self->tree_sequence->tables->edges.parent;
tsk_value_transition_t *restrict T = self->transitions;
tsk_edge_list_node_t *record;
tsk_edge_list_t records_out, records_in;
tsk_edge_t edge;
double left, right;
tsk_id_t u;
tsk_id_t u, c, p, j, e;
tsk_value_transition_t *vt;

ret = tsk_diff_iter_next(&self->diffs, &left, &right, &records_out, &records_in);
if (ret < 0) {
goto out;
}
tsk_tree_position_next(&self->tree_pos);
tsk_bug_assert(self->tree_pos.index != -1);
tsk_bug_assert(self->tree_pos.index == self->tree.index);

for (record = records_out.head; record != NULL; record = record->next) {
u = record->edge.child;
for (j = self->tree_pos.out.start; j < self->tree_pos.out.stop; j++) {
e = self->tree_pos.out.order[j];
c = edges_child[e];
u = c;
if (T_index[u] == TSK_NULL) {
/* Ensure the subtree we're detaching has a transition at the root */
while (T_index[u] == TSK_NULL) {
u = parent[u];
tsk_bug_assert(u != TSK_NULL);
}
tsk_bug_assert(self->num_transitions < self->max_transitions);
T_index[record->edge.child] = (tsk_id_t) self->num_transitions;
T[self->num_transitions].tree_node = record->edge.child;
T_index[c] = (tsk_id_t) self->num_transitions;
T[self->num_transitions].tree_node = c;
T[self->num_transitions].value = T[T_index[u]].value;
self->num_transitions++;
}
parent[record->edge.child] = TSK_NULL;
parent[c] = TSK_NULL;
}

for (record = records_in.head; record != NULL; record = record->next) {
edge = record->edge;
parent[edge.child] = edge.parent;
u = edge.parent;
if (parent[edge.parent] == TSK_NULL) {
for (j = self->tree_pos.in.start; j < self->tree_pos.in.stop; j++) {
e = self->tree_pos.in.order[j];
c = edges_child[e];
p = edges_parent[e];
parent[c] = p;
u = p;
if (parent[p] == TSK_NULL) {
/* Grafting onto a new root. */
if (T_index[record->edge.parent] == TSK_NULL) {
T_index[edge.parent] = (tsk_id_t) self->num_transitions;
if (T_index[p] == TSK_NULL) {
T_index[p] = (tsk_id_t) self->num_transitions;
tsk_bug_assert(self->num_transitions < self->max_transitions);
T[self->num_transitions].tree_node = edge.parent;
T[self->num_transitions].value = T[T_index[edge.child]].value;
T[self->num_transitions].tree_node = p;
T[self->num_transitions].value = T[T_index[c]].value;
self->num_transitions++;
}
} else {
Expand All @@ -356,18 +356,17 @@ tsk_ls_hmm_update_tree(tsk_ls_hmm_t *self)
}
tsk_bug_assert(u != TSK_NULL);
}
tsk_bug_assert(T_index[u] != -1 && T_index[edge.child] != -1);
if (T[T_index[u]].value == T[T_index[edge.child]].value) {
vt = &T[T_index[edge.child]];
tsk_bug_assert(T_index[u] != -1 && T_index[c] != -1);
if (T[T_index[u]].value == T[T_index[c]].value) {
vt = &T[T_index[c]];
/* Mark the value transition as unusued */
vt->value = -1;
vt->tree_node = TSK_NULL;
T_index[edge.child] = TSK_NULL;
T_index[c] = TSK_NULL;
}
}

ret = tsk_ls_hmm_remove_dead_roots(self);
out:
return ret;
}

Expand Down
7 changes: 5 additions & 2 deletions c/tskit/haplotype_matching.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* MIT License
*
* Copyright (c) 2019-2022 Tskit Developers
* Copyright (c) 2019-2023 Tskit Developers
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -98,7 +98,10 @@ typedef struct _tsk_ls_hmm_t {
tsk_size_t num_nodes;
/* state */
tsk_tree_t tree;
tsk_diff_iter_t diffs;
/* NOTE: this tree_position will be redundant once we integrate the top-level
* tree class with this.
*/
tsk_tree_position_t tree_pos;
tsk_id_t *parent;
/* The probability value transitions on the tree */
tsk_value_transition_t *transitions;
Expand Down

0 comments on commit 158ed1a

Please sign in to comment.