diff --git a/c/tests/test_haplotype_matching.c b/c/tests/test_haplotype_matching.c index 7a8bda84dc..943654ff91 100644 --- a/c/tests/test_haplotype_matching.c +++ b/c/tests/test_haplotype_matching.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -28,46 +28,6 @@ #include #include -/**************************************************************** - * TestHMM - ****************************************************************/ - -static double -tsk_ls_hmm_compute_normalisation_factor_site_test(tsk_ls_hmm_t *TSK_UNUSED(self)) -{ - return 1.0; -} - -static int -tsk_ls_hmm_next_probability_test(tsk_ls_hmm_t *TSK_UNUSED(self), - tsk_id_t TSK_UNUSED(site_id), double TSK_UNUSED(p_last), bool TSK_UNUSED(is_match), - tsk_id_t TSK_UNUSED(node), double *result) -{ - *result = rand(); - /* printf("next proba = %f\n", *result); */ - return 0; -} - -static int -run_test_hmm(tsk_ls_hmm_t *hmm, int32_t *haplotype, tsk_compressed_matrix_t *output) -{ - int ret = 0; - - srand(1); - - ret = tsk_ls_hmm_run(hmm, haplotype, tsk_ls_hmm_next_probability_test, - tsk_ls_hmm_compute_normalisation_factor_site_test, output); - if (ret != 0) { - goto out; - } -out: - return ret; -} - -/**************************************************************** - * TestHMM - ****************************************************************/ - static void test_single_tree_missing_alleles(void) { @@ -206,6 +166,7 @@ test_single_tree_match_impossible(void) tsk_treeseq_t ts; tsk_ls_hmm_t ls_hmm; tsk_compressed_matrix_t forward; + tsk_compressed_matrix_t backward; tsk_viterbi_matrix_t viterbi; double rho[] = { 0.0, 0.25, 0.25 }; @@ -228,8 +189,16 @@ test_single_tree_match_impossible(void) tsk_viterbi_matrix_print_state(&viterbi, _devnull); tsk_ls_hmm_print_state(&ls_hmm, _devnull); + ret = tsk_ls_hmm_backward(&ls_hmm, h, forward.normalisation_factor, &backward, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MATCH_IMPOSSIBLE); + tsk_compressed_matrix_print_state(&backward, _devnull); + /* tsk_compressed_matrix_print_state(&forward, stdout); */ + /* tsk_compressed_matrix_print_state(&backward, stdout); */ + tsk_ls_hmm_print_state(&ls_hmm, _devnull); + tsk_ls_hmm_free(&ls_hmm); tsk_compressed_matrix_free(&forward); + tsk_compressed_matrix_free(&backward); tsk_viterbi_matrix_free(&viterbi); tsk_treeseq_free(&ts); } @@ -275,12 +244,15 @@ test_single_tree_errors(void) ret = tsk_compressed_matrix_store_site(&forward, 4, 0, 0, NULL); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SITE_OUT_OF_BOUNDS); - T[0].tree_node = -1; - T[0].value = 0; - ret = tsk_compressed_matrix_store_site(&forward, 0, 1, 1, T); - CU_ASSERT_EQUAL_FATAL(ret, 0); - ret = tsk_compressed_matrix_decode(&forward, (double *) decoded); - CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + /* FIXME disabling this tests for now because we filter out negative + * nodes when storing now, to accomodate some oddness in the initial + * conditions of the backward matrix. */ + /* T[0].tree_node = -1; */ + /* T[0].value = 0; */ + /* ret = tsk_compressed_matrix_store_site(&forward, 0, 1, 1, T); */ + /* CU_ASSERT_EQUAL_FATAL(ret, 0); */ + /* ret = tsk_compressed_matrix_decode(&forward, (double *) decoded); */ + /* CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); */ T[0].tree_node = 7; T[0].value = 0; @@ -443,7 +415,7 @@ test_multi_tree_exact_match(void) int ret = 0; tsk_treeseq_t ts; tsk_ls_hmm_t ls_hmm; - tsk_compressed_matrix_t forward; + tsk_compressed_matrix_t forward, backward; tsk_viterbi_matrix_t viterbi; double rho[] = { 0.0, 0.25, 0.25 }; @@ -465,6 +437,13 @@ test_multi_tree_exact_match(void) ret = tsk_compressed_matrix_decode(&forward, decoded_compressed_matrix); CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_ls_hmm_backward(&ls_hmm, h, forward.normalisation_factor, &backward, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_ls_hmm_print_state(&ls_hmm, _devnull); + tsk_compressed_matrix_print_state(&backward, _devnull); + ret = tsk_compressed_matrix_decode(&backward, decoded_compressed_matrix); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_ls_hmm_viterbi(&ls_hmm, h, &viterbi, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); tsk_viterbi_matrix_print_state(&viterbi, _devnull); @@ -492,6 +471,7 @@ test_multi_tree_exact_match(void) tsk_ls_hmm_free(&ls_hmm); tsk_compressed_matrix_free(&forward); + tsk_compressed_matrix_free(&backward); tsk_viterbi_matrix_free(&viterbi); tsk_treeseq_free(&ts); } @@ -529,7 +509,8 @@ test_caterpillar_tree_many_values(void) int ret = 0; tsk_ls_hmm_t ls_hmm; tsk_compressed_matrix_t matrix; - double unused[] = { 0, 0, 0, 0, 0 }; + double rho[] = { 0.1, 0.1, 0.1, 0.1, 0.1 }; + double mu[] = { 0.0, 0.0, 0.0, 0.0, 0.0 }; int32_t h[] = { 0, 0, 0, 0, 0 }; tsk_size_t n[] = { 8, @@ -542,11 +523,11 @@ test_caterpillar_tree_many_values(void) for (j = 0; j < sizeof(n) / sizeof(*n); j++) { ts = caterpillar_tree(n[j], 5, n[j] - 2); - ret = tsk_ls_hmm_init(&ls_hmm, ts, unused, unused, 0); + ret = tsk_ls_hmm_init(&ls_hmm, ts, rho, mu, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_compressed_matrix_init(&matrix, ts, 1 << 10, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); - ret = run_test_hmm(&ls_hmm, h, &matrix); + ret = tsk_ls_hmm_forward(&ls_hmm, h, &matrix, TSK_NO_INIT); CU_ASSERT_EQUAL_FATAL(ret, 0); tsk_compressed_matrix_print_state(&matrix, _devnull); tsk_ls_hmm_print_state(&ls_hmm, _devnull); @@ -559,13 +540,13 @@ test_caterpillar_tree_many_values(void) j = 40; ts = caterpillar_tree(j, 5, j - 2); - ret = tsk_ls_hmm_init(&ls_hmm, ts, unused, unused, 0); + ret = tsk_ls_hmm_init(&ls_hmm, ts, rho, mu, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_compressed_matrix_init(&matrix, ts, 1 << 20, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); - /* Short circuit this value so we can run the test in reasonable time */ - ls_hmm.max_parsimony_words = 1; - ret = run_test_hmm(&ls_hmm, h, &matrix); + /* Short circuit this value so we can run the test */ + ls_hmm.max_parsimony_words = 0; + ret = tsk_ls_hmm_forward(&ls_hmm, h, &matrix, TSK_NO_INIT); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_TOO_MANY_VALUES); tsk_ls_hmm_free(&ls_hmm); diff --git a/c/tskit/haplotype_matching.c b/c/tskit/haplotype_matching.c index d6fdfd7f46..ea8853e72e 100644 --- a/c/tskit/haplotype_matching.c +++ b/c/tskit/haplotype_matching.c @@ -230,10 +230,9 @@ tsk_ls_hmm_free(tsk_ls_hmm_t *self) } static int -tsk_ls_hmm_reset(tsk_ls_hmm_t *self) +tsk_ls_hmm_reset(tsk_ls_hmm_t *self, double value) { int ret = 0; - double n = (double) self->num_samples; tsk_size_t j; tsk_id_t u; const tsk_id_t *samples; @@ -257,7 +256,7 @@ tsk_ls_hmm_reset(tsk_ls_hmm_t *self) for (j = 0; j < self->num_samples; j++) { u = samples[j]; self->transitions[j].tree_node = u; - self->transitions[j].value = 1.0 / n; + self->transitions[j].value = value; self->transition_index[u] = (tsk_id_t) j; } self->num_transitions = self->num_samples; @@ -300,7 +299,7 @@ tsk_ls_hmm_remove_dead_roots(tsk_ls_hmm_t *self) } static int -tsk_ls_hmm_update_tree(tsk_ls_hmm_t *self) +tsk_ls_hmm_update_tree(tsk_ls_hmm_t *self, int direction) { int ret = 0; tsk_id_t *restrict parent = self->parent; @@ -311,11 +310,15 @@ tsk_ls_hmm_update_tree(tsk_ls_hmm_t *self) tsk_id_t u, c, p, j, e; tsk_value_transition_t *vt; - tsk_tree_position_next(&self->tree_pos); + if (direction == TSK_DIR_FORWARD) { + tsk_tree_position_next(&self->tree_pos); + } else { + tsk_tree_position_prev(&self->tree_pos); + } tsk_bug_assert(self->tree_pos.index != -1); tsk_bug_assert(self->tree_pos.index == self->tree.index); - for (j = self->tree_pos.out.start; j < self->tree_pos.out.stop; j++) { + for (j = self->tree_pos.out.start; j != self->tree_pos.out.stop; j += direction) { e = self->tree_pos.out.order[j]; c = edges_child[e]; u = c; @@ -334,7 +337,7 @@ tsk_ls_hmm_update_tree(tsk_ls_hmm_t *self) parent[c] = TSK_NULL; } - for (j = self->tree_pos.in.start; j < self->tree_pos.in.stop; j++) { + for (j = self->tree_pos.in.start; j != self->tree_pos.in.stop; j += direction) { e = self->tree_pos.in.order[j]; c = edges_child[e]; p = edges_parent[e]; @@ -920,7 +923,7 @@ tsk_ls_hmm_compress(tsk_ls_hmm_t *self) } static int -tsk_ls_hmm_process_site( +tsk_ls_hmm_process_site_forward( tsk_ls_hmm_t *self, const tsk_site_t *site, int32_t haplotype_state) { int ret = 0; @@ -959,28 +962,23 @@ tsk_ls_hmm_process_site( return ret; } -int -tsk_ls_hmm_run(tsk_ls_hmm_t *self, int32_t *haplotype, - int (*next_probability)(tsk_ls_hmm_t *, tsk_id_t, double, bool, tsk_id_t, double *), - double (*compute_normalisation_factor)(struct _tsk_ls_hmm_t *), void *output) +static int +tsk_ls_hmm_run_forward(tsk_ls_hmm_t *self, int32_t *haplotype) { int ret = 0; int t_ret; const tsk_site_t *sites; tsk_size_t j, num_sites; + const double n = (double) self->num_samples; - self->next_probability = next_probability; - self->compute_normalisation_factor = compute_normalisation_factor; - self->output = output; - - ret = tsk_ls_hmm_reset(self); + ret = tsk_ls_hmm_reset(self, 1 / n); if (ret != 0) { goto out; } for (t_ret = tsk_tree_first(&self->tree); t_ret == TSK_TREE_OK; t_ret = tsk_tree_next(&self->tree)) { - ret = tsk_ls_hmm_update_tree(self); + ret = tsk_ls_hmm_update_tree(self, TSK_DIR_FORWARD); if (ret != 0) { goto out; } @@ -990,7 +988,8 @@ tsk_ls_hmm_run(tsk_ls_hmm_t *self, int32_t *haplotype, goto out; } for (j = 0; j < num_sites; j++) { - ret = tsk_ls_hmm_process_site(self, &sites[j], haplotype[sites[j].id]); + ret = tsk_ls_hmm_process_site_forward( + self, &sites[j], haplotype[sites[j].id]); if (ret != 0) { goto out; } @@ -1080,11 +1079,168 @@ tsk_ls_hmm_forward(tsk_ls_hmm_t *self, int32_t *haplotype, goto out; } } - ret = tsk_ls_hmm_run(self, haplotype, tsk_ls_hmm_next_probability_forward, - tsk_ls_hmm_compute_normalisation_factor_forward, output); + + self->next_probability = tsk_ls_hmm_next_probability_forward; + self->compute_normalisation_factor = tsk_ls_hmm_compute_normalisation_factor_forward; + self->output = output; + + ret = tsk_ls_hmm_run_forward(self, haplotype); +out: + return ret; +} + +/**************************************************************** + * Backward Algorithm + ****************************************************************/ + +static int +tsk_ls_hmm_next_probability_backward(tsk_ls_hmm_t *self, tsk_id_t site_id, double p_last, + bool is_match, tsk_id_t TSK_UNUSED(node), double *result) +{ + const double mu = self->mutation_rate[site_id]; + const double num_alleles = self->num_alleles[site_id]; + double p_e; + + p_e = mu; + if (is_match) { + p_e = 1 - (num_alleles - 1) * mu; + } + *result = p_last * p_e; + return 0; +} + +static int +tsk_ls_hmm_process_site_backward(tsk_ls_hmm_t *self, const tsk_site_t *site, + const int32_t haplotype_state, const double normalisation_factor) +{ + int ret = 0; + double x, b_last_sum; + tsk_compressed_matrix_t *output = (tsk_compressed_matrix_t *) self->output; + tsk_value_transition_t *restrict T = self->transitions; + const unsigned int precision = (unsigned int) self->precision; + const double rho = self->recombination_rate[site->id]; + const double n = (double) self->num_samples; + tsk_size_t j; + + /* FIXME!!! We are calling compress twice here because we need to compress + * immediately before calling store_site in order to filter out -1 nodes, + * and also (crucially) to ensure that the value transitions are listed + * in preorder, which we rely on later for decoding. + */ + ret = tsk_ls_hmm_compress(self); + if (ret != 0) { + goto out; + } + ret = tsk_compressed_matrix_store_site( + output, site->id, normalisation_factor, (tsk_size_t) self->num_transitions, T); + if (ret != 0) { + goto out; + } + + ret = tsk_ls_hmm_update_probabilities(self, site, haplotype_state); + if (ret != 0) { + goto out; + } + /* DO WE NEED THIS compress?? See above */ + ret = tsk_ls_hmm_compress(self); + if (ret != 0) { + goto out; + } + tsk_bug_assert(self->num_transitions <= self->num_samples); + b_last_sum = self->compute_normalisation_factor(self); + for (j = 0; j < self->num_transitions; j++) { + tsk_bug_assert(T[j].tree_node != TSK_NULL); + x = rho * b_last_sum / n + (1 - rho) * T[j].value; + x /= normalisation_factor; + T[j].value = tsk_round(x, precision); + } +out: + return ret; +} + +static int +tsk_ls_hmm_run_backward( + tsk_ls_hmm_t *self, int32_t *haplotype, const double *forward_norm) +{ + int ret = 0; + int t_ret; + const tsk_site_t *sites; + double s; + tsk_size_t num_sites; + tsk_id_t j; + + ret = tsk_ls_hmm_reset(self, 1); if (ret != 0) { goto out; } + + for (t_ret = tsk_tree_last(&self->tree); t_ret == TSK_TREE_OK; + t_ret = tsk_tree_prev(&self->tree)) { + ret = tsk_ls_hmm_update_tree(self, TSK_DIR_REVERSE); + if (ret != 0) { + goto out; + } + /* tsk_ls_hmm_check_state(self); */ + ret = tsk_tree_get_sites(&self->tree, &sites, &num_sites); + if (ret != 0) { + goto out; + } + for (j = (tsk_id_t) num_sites - 1; j >= 0; j--) { + s = forward_norm[sites[j].id]; + if (s <= 0) { + /* NOTE: I'm not sure if this is the correct interpretation, + * but norm values of 0 do lead to problems, and this seems + * like a simple way of guarding against it. We do seem to + * get norm values of 0 with impossible matches from the fwd + * matrix. + */ + ret = TSK_ERR_MATCH_IMPOSSIBLE; + goto out; + } + ret = tsk_ls_hmm_process_site_backward( + self, &sites[j], haplotype[sites[j].id], s); + if (ret != 0) { + goto out; + } + } + } + /* Set to zero so we can print and check the state OK. */ + self->num_transitions = 0; + if (t_ret != 0) { + ret = t_ret; + goto out; + } +out: + return ret; +} + +int +tsk_ls_hmm_backward(tsk_ls_hmm_t *self, int32_t *haplotype, const double *forward_norm, + tsk_compressed_matrix_t *output, tsk_flags_t options) +{ + int ret = 0; + + if (!(options & TSK_NO_INIT)) { + ret = tsk_compressed_matrix_init(output, self->tree_sequence, 0, 0); + if (ret != 0) { + goto out; + } + } else { + if (output->tree_sequence != self->tree_sequence) { + ret = TSK_ERR_BAD_PARAM_VALUE; + goto out; + } + ret = tsk_compressed_matrix_clear(output); + if (ret != 0) { + goto out; + } + } + + self->next_probability = tsk_ls_hmm_next_probability_backward; + self->compute_normalisation_factor = tsk_ls_hmm_compute_normalisation_factor_forward; + self->output = output; + + ret = tsk_ls_hmm_run_backward(self, haplotype, forward_norm); out: return ret; } @@ -1162,11 +1318,12 @@ tsk_ls_hmm_viterbi(tsk_ls_hmm_t *self, int32_t *haplotype, tsk_viterbi_matrix_t goto out; } } - ret = tsk_ls_hmm_run(self, haplotype, tsk_ls_hmm_next_probability_viterbi, - tsk_ls_hmm_compute_normalisation_factor_viterbi, output); - if (ret != 0) { - goto out; - } + + self->next_probability = tsk_ls_hmm_next_probability_viterbi; + self->compute_normalisation_factor = tsk_ls_hmm_compute_normalisation_factor_viterbi; + self->output = output; + + ret = tsk_ls_hmm_run_forward(self, haplotype); out: return ret; } @@ -1279,9 +1436,11 @@ tsk_compressed_matrix_store_site(tsk_compressed_matrix_t *self, tsk_id_t site, } for (j = 0; j < num_transitions; j++) { + tsk_bug_assert(transitions[j].tree_node >= 0); self->values[site][j] = transitions[j].value; self->nodes[site][j] = transitions[j].tree_node; } + out: return ret; } diff --git a/c/tskit/haplotype_matching.h b/c/tskit/haplotype_matching.h index e4d82bef03..4809939458 100644 --- a/c/tskit/haplotype_matching.h +++ b/c/tskit/haplotype_matching.h @@ -134,6 +134,7 @@ typedef struct _tsk_ls_hmm_t { void *output; } tsk_ls_hmm_t; +/* TODO constify these APIs */ int tsk_ls_hmm_init(tsk_ls_hmm_t *self, tsk_treeseq_t *tree_sequence, double *recombination_rate, double *mutation_rate, tsk_flags_t options); int tsk_ls_hmm_set_precision(tsk_ls_hmm_t *self, unsigned int precision); @@ -141,11 +142,10 @@ int tsk_ls_hmm_free(tsk_ls_hmm_t *self); void tsk_ls_hmm_print_state(tsk_ls_hmm_t *self, FILE *out); int tsk_ls_hmm_forward(tsk_ls_hmm_t *self, int32_t *haplotype, tsk_compressed_matrix_t *output, tsk_flags_t options); +int tsk_ls_hmm_backward(tsk_ls_hmm_t *self, int32_t *haplotype, + const double *forward_norm, tsk_compressed_matrix_t *output, tsk_flags_t options); int tsk_ls_hmm_viterbi(tsk_ls_hmm_t *self, int32_t *haplotype, tsk_viterbi_matrix_t *output, tsk_flags_t options); -int tsk_ls_hmm_run(tsk_ls_hmm_t *self, int32_t *haplotype, - int (*next_probability)(tsk_ls_hmm_t *, tsk_id_t, double, bool, tsk_id_t, double *), - double (*compute_normalisation_factor)(struct _tsk_ls_hmm_t *), void *output); int tsk_compressed_matrix_init(tsk_compressed_matrix_t *self, tsk_treeseq_t *tree_sequence, tsk_size_t block_size, tsk_flags_t options); diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 5c6bd29986..6de1071f26 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -13260,6 +13260,64 @@ LsHmm_forward_matrix(LsHmm *self, PyObject *args) return ret; } +static PyObject * +LsHmm_backward_matrix(LsHmm *self, PyObject *args) +{ + int err; + PyObject *ret = NULL; + PyObject *haplotype = NULL; + PyObject *forward_norm = NULL; + CompressedMatrix *compressed_matrix = NULL; + PyArrayObject *haplotype_array = NULL; + PyArrayObject *forward_norm_array = NULL; + npy_intp *shape, num_sites; + + if (LsHmm_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "OOO!", &haplotype, &forward_norm, &CompressedMatrixType, + &compressed_matrix)) { + goto out; + } + num_sites = (npy_intp) tsk_treeseq_get_num_sites(self->tree_sequence->tree_sequence); + + haplotype_array = (PyArrayObject *) PyArray_FROMANY( + haplotype, NPY_INT32, 1, 1, NPY_ARRAY_IN_ARRAY); + if (haplotype_array == NULL) { + goto out; + } + shape = PyArray_DIMS(haplotype_array); + if (shape[0] != num_sites) { + PyErr_SetString( + PyExc_ValueError, "haplotype array must have dimension (num_sites,)"); + goto out; + } + + forward_norm_array = (PyArrayObject *) PyArray_FROMANY( + forward_norm, NPY_FLOAT64, 1, 1, NPY_ARRAY_IN_ARRAY); + if (forward_norm_array == NULL) { + goto out; + } + shape = PyArray_DIMS(forward_norm_array); + if (shape[0] != num_sites) { + PyErr_SetString( + PyExc_ValueError, "forward_norm array must have dimension (num_sites,)"); + goto out; + } + err = tsk_ls_hmm_backward(self->ls_hmm, PyArray_DATA(haplotype_array), + PyArray_DATA(forward_norm_array), compressed_matrix->compressed_matrix, + TSK_NO_INIT); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + Py_XDECREF(haplotype_array); + Py_XDECREF(forward_norm_array); + return ret; +} + static PyObject * LsHmm_viterbi_matrix(LsHmm *self, PyObject *args) { @@ -13306,6 +13364,10 @@ static PyMethodDef LsHmm_methods[] = { .ml_meth = (PyCFunction) LsHmm_forward_matrix, .ml_flags = METH_VARARGS, .ml_doc = "Returns the tree encoded forward matrix for a given haplotype" }, + { .ml_name = "backward_matrix", + .ml_meth = (PyCFunction) LsHmm_backward_matrix, + .ml_flags = METH_VARARGS, + .ml_doc = "Returns the tree encoded backward matrix for a given haplotype" }, { .ml_name = "viterbi_matrix", .ml_meth = (PyCFunction) LsHmm_viterbi_matrix, .ml_flags = METH_VARARGS, diff --git a/python/tests/test_haplotype_matching.py b/python/tests/test_haplotype_matching.py index b09ebcc005..725c6bf38c 100644 --- a/python/tests/test_haplotype_matching.py +++ b/python/tests/test_haplotype_matching.py @@ -22,13 +22,17 @@ """ Python implementation of the Li and Stephens forwards and backwards algorithms. """ -import itertools +import warnings import lshmm as ls import msprime import numpy as np +import numpy.testing as nt +import pytest +import _tskit import tskit +from tests import tsutil MISSING = -1 @@ -109,8 +113,8 @@ def __init__( self.N = np.zeros(ts.num_nodes, dtype=int) # Efficiently compute the allelic state at a site self.allelic_state = np.zeros(ts.num_nodes, dtype=int) - 1 - # Diffs so we can can update T and T_index between trees. - self.edge_diffs = self.ts.edge_diffs() + # TreePosition so we can can update T and T_index between trees. + self.tree_pos = tsutil.TreePosition(ts) self.parent = np.zeros(self.ts.num_nodes, dtype=int) - 1 self.tree = tskit.Tree(self.ts) self.output = None @@ -229,16 +233,24 @@ def compute(u, parent_state): if T_parent[j] != -1: self.N[T_parent[j]] -= self.N[j] - def update_tree(self): + def update_tree(self, direction=tskit.FORWARD): """ Update the internal data structures to move on to the next tree. """ parent = self.parent T_index = self.T_index T = self.T - _, edges_out, edges_in = next(self.edge_diffs) - - for edge in edges_out: + if direction == tskit.FORWARD: + self.tree_pos.next() + else: + self.tree_pos.prev() + assert self.tree_pos.index == self.tree.index + + for j in range( + self.tree_pos.out_range.start, self.tree_pos.out_range.stop, direction + ): + e = self.tree_pos.out_range.order[j] + edge = self.ts.edge(e) u = edge.child if T_index[u] == -1: # Make sure the subtree we're detaching has an T_index-value at the root. @@ -251,7 +263,11 @@ def update_tree(self): ) parent[edge.child] = -1 - for edge in edges_in: + for j in range( + self.tree_pos.in_range.start, self.tree_pos.in_range.stop, direction + ): + e = self.tree_pos.in_range.order[j] + edge = self.ts.edge(e) parent[edge.child] = edge.parent u = edge.parent if parent[edge.parent] == -1: @@ -320,6 +336,7 @@ def update_probabilities(self, site, haplotype_state): match = ( haplotype_state == MISSING or haplotype_state == allelic_state[v] ) + # Note that the node u is used only by Viterbi st.value = self.compute_next_probability(site.id, st.value, match, u) # Unset the states @@ -327,59 +344,61 @@ def update_probabilities(self, site, haplotype_state): for mutation in site.mutations: allelic_state[mutation.node] = -1 - def process_site(self, site, haplotype_state, forwards=True): - if forwards: - # Forwards algorithm, or forwards pass in Viterbi - self.update_probabilities(site, haplotype_state) - self.compress() - s = self.compute_normalisation_factor() - for st in self.T: - if st.tree_node != tskit.NULL: - st.value /= s - st.value = round(st.value, self.precision) - self.output.store_site( - site.id, s, [(st.tree_node, st.value) for st in self.T] - ) + def process_site(self, site, haplotype_state): + self.update_probabilities(site, haplotype_state) + self.compress() + s = self.compute_normalisation_factor() + for st in self.T: + assert st.tree_node != tskit.NULL + # if st.tree_node != tskit.NULL: + st.value /= s + st.value = round(st.value, self.precision) + self.output.store_site(site.id, s, [(st.tree_node, st.value) for st in self.T]) + + def compute_emission_proba(self, site_id, is_match): + mu = self.mu[site_id] + n_alleles = self.n_alleles[site_id] + if self.scale_mutation_based_on_n_alleles: + if is_match: + # Scale mutation based on the number of alleles + # - so the mutation rate is the mutation rate to one of the + # alleles. The overall mutation rate is then + # (n_alleles - 1) * mutation_rate. + p_e = 1 - (n_alleles - 1) * mu + else: + p_e = mu - mu * (n_alleles == 1) + # Added boolean in case we're at an invariant site else: - # Backwards algorithm - self.output.store_site( - site.id, - self.output.normalisation_factor[site.id], - [(st.tree_node, st.value) for st in self.T], - ) - self.update_probabilities(site, haplotype_state) - self.compress() - b_last_sum = self.compute_normalisation_factor() - s = self.output.normalisation_factor[site.id] - for st in self.T: - if st.tree_node != tskit.NULL: - st.value = ( - self.rho[site.id] / self.ts.num_samples - ) * b_last_sum + (1 - self.rho[site.id]) * st.value - st.value /= s - st.value = round(st.value, self.precision) - - def run_forward(self, h): - n = self.ts.num_samples - self.tree.clear() - for u in self.ts.samples(): - self.T_index[u] = len(self.T) - self.T.append(ValueTransition(tree_node=u, value=1 / n)) - while self.tree.next(): - self.update_tree() - for site in self.tree.sites(): - self.process_site(site, h[site.id]) - return self.output + # No scaling based on the number of alleles + # - so the mutation rate is the mutation rate to anything. + # This means that we must rescale the mutation rate to a different + # allele, by the number of alleles. + if n_alleles == 1: # In case we're at an invariant site + if is_match: + p_e = 1 + else: + p_e = 0 + else: + if is_match: + p_e = 1 - mu + else: + p_e = mu / (n_alleles - 1) + return p_e - def run_backward(self, h): + def initialise(self, value): self.tree.clear() for u in self.ts.samples(): - self.T_index[u] = len(self.T) - self.T.append(ValueTransition(tree_node=u, value=1)) + j = len(self.T) + self.T_index[u] = j + self.T.append(ValueTransition(tree_node=u, value=value)) + + def run(self, h): + n = self.ts.num_samples + self.initialise(1 / n) while self.tree.next(): self.update_tree() for site in self.tree.sites(): - self.process_site(site, h[site.id], forwards=False) + self.process_site(site, h[site.id]) return self.output def compute_normalisation_factor(self): @@ -389,6 +408,146 @@ def compute_next_probability(self, site_id, p_last, is_match, node): raise NotImplementedError() +class ForwardAlgorithm(LsHmmAlgorithm): + """ + The Li and Stephens forward algorithm. + """ + + def __init__( + self, ts, rho, mu, alleles, n_alleles, scale_mutation=False, precision=10 + ): + super().__init__( + ts, + rho, + mu, + alleles, + n_alleles, + precision=precision, + scale_mutation=scale_mutation, + ) + self.output = CompressedMatrix(ts) + + def compute_normalisation_factor(self): + s = 0 + for j, st in enumerate(self.T): + assert st.tree_node != tskit.NULL + # assert self.N[j] > 0 + s += self.N[j] * st.value + return s + + def compute_next_probability(self, site_id, p_last, is_match, node): + rho = self.rho[site_id] + n = self.ts.num_samples + p_e = self.compute_emission_proba(site_id, is_match) + p_t = p_last * (1 - rho) + rho / n + return p_t * p_e + + +class BackwardAlgorithm(ForwardAlgorithm): + """ + The Li and Stephens backward algorithm. + """ + + def compute_next_probability(self, site_id, p_next, is_match, node): + p_e = self.compute_emission_proba(site_id, is_match) + return p_next * p_e + + def process_site(self, site, haplotype_state, s): + # FIXME see nodes in the C code for why we have two calls to + # compress + self.compress() + self.output.store_site( + site.id, + s, + [(st.tree_node, st.value) for st in self.T], + ) + self.update_probabilities(site, haplotype_state) + # FIXME see nodes in the C code for why we have two calls to + # compress + self.compress() + b_last_sum = self.compute_normalisation_factor() + n = self.ts.num_samples + rho = self.rho[site.id] + for st in self.T: + if st.tree_node != tskit.NULL: + st.value = rho * b_last_sum / n + (1 - rho) * st.value + st.value /= s + st.value = round(st.value, self.precision) + + def run(self, h, normalisation_factor): + self.initialise(value=1) + while self.tree.prev(): + self.update_tree(direction=tskit.REVERSE) + for site in reversed(list(self.tree.sites())): + self.process_site(site, h[site.id], normalisation_factor[site.id]) + return self.output + + +class ViterbiAlgorithm(LsHmmAlgorithm): + """ + Runs the Li and Stephens Viterbi algorithm. + """ + + def __init__( + self, ts, rho, mu, alleles, n_alleles, scale_mutation=False, precision=10 + ): + super().__init__( + ts, + rho, + mu, + alleles, + n_alleles, + precision=precision, + scale_mutation=scale_mutation, + ) + self.output = ViterbiMatrix(ts) + + def compute_normalisation_factor(self): + max_st = ValueTransition(value=-1) + for st in self.T: + assert st.tree_node != tskit.NULL + if st.value > max_st.value: + max_st = st + if max_st.value == 0: + raise ValueError( + "Trying to match non-existent allele with zero mutation rate" + ) + return max_st.value + + def compute_next_probability(self, site_id, p_last, is_match, node): + rho = self.rho[site_id] + n = self.ts.num_samples + + p_no_recomb = p_last * (1 - rho + rho / n) + p_recomb = rho / n + recombination_required = False + if p_no_recomb > p_recomb: + p_t = p_no_recomb + else: + p_t = p_recomb + recombination_required = True + self.output.add_recombination_required(site_id, node, recombination_required) + + p_e = self.compute_emission_proba(site_id, is_match) + return p_t * p_e + + +def assert_compressed_matrices_equal(cm1, cm2): + nt.assert_array_almost_equal(cm1.normalisation_factor, cm2.normalisation_factor) + + for j in range(cm1.num_sites): + site1 = cm1.get_site(j) + site2 = cm2.get_site(j) + assert len(site1) == len(site2) + site1 = dict(site1) + site2 = dict(site2) + + assert set(site1.keys()) == set(site2.keys()) + for node in site1.keys(): + # TODO the precision value should be used as a parameter here + nt.assert_allclose(site1[node], site2[node], rtol=1e-5, atol=1e-8) + + class CompressedMatrix: """ Class representing a num_samples x num_sites matrix compressed by a @@ -398,18 +557,15 @@ class CompressedMatrix: values are on the path). """ - def __init__(self, ts, normalisation_factor=None): + def __init__(self, ts): self.ts = ts self.num_sites = ts.num_sites self.num_samples = ts.num_samples self.value_transitions = [None for _ in range(self.num_sites)] - if normalisation_factor is None: - self.normalisation_factor = np.zeros(self.num_sites) - else: - self.normalisation_factor = normalisation_factor - assert len(self.normalisation_factor) == self.num_sites + self.normalisation_factor = np.zeros(self.num_sites) def store_site(self, site, normalisation_factor, value_transitions): + assert all(u >= 0 for u, _ in value_transitions) self.normalisation_factor[site] = normalisation_factor self.value_transitions[site] = value_transitions @@ -428,25 +584,18 @@ def decode(self): Decodes the tree encoding of the values into an explicit matrix. """ + sample_index_map = np.zeros(self.ts.num_nodes, dtype=int) - 1 + sample_index_map[self.ts.samples()] = np.arange(self.ts.num_samples) A = np.zeros((self.num_sites, self.num_samples)) for tree in self.ts.trees(): for site in tree.sites(): - f = dict(self.value_transitions[site.id]) - for j, u in enumerate(self.ts.samples()): - while u not in f: - u = tree.parent(u) - A[site.id, j] = f[u] + for node, value in self.value_transitions[site.id]: + for u in tree.samples(node): + j = sample_index_map[u] + A[site.id, j] = value return A -class ForwardMatrix(CompressedMatrix): - """Class representing a compressed forward matrix.""" - - -class BackwardMatrix(CompressedMatrix): - """Class representing a compressed backward matrix.""" - - class ViterbiMatrix(CompressedMatrix): """ Class representing the compressed Viterbi matrix. @@ -527,206 +676,7 @@ def traceback(self): return match -class ForwardAlgorithm(LsHmmAlgorithm): - """Runs the Li and Stephens forward algorithm.""" - - def __init__( - self, ts, rho, mu, alleles, n_alleles, scale_mutation=False, precision=10 - ): - super().__init__( - ts, - rho, - mu, - alleles, - n_alleles, - precision=precision, - scale_mutation=scale_mutation, - ) - self.output = ForwardMatrix(ts) - - def compute_normalisation_factor(self): - s = 0 - for j, st in enumerate(self.T): - assert st.tree_node != tskit.NULL - assert self.N[j] > 0 - s += self.N[j] * st.value - return s - - def compute_next_probability( - self, site_id, p_last, is_match, node - ): # Note node only used in Viterbi - rho = self.rho[site_id] - mu = self.mu[site_id] - n = self.ts.num_samples - n_alleles = self.n_alleles[site_id] - - if self.scale_mutation_based_on_n_alleles: - if is_match: - # Scale mutation based on the number of alleles - # - so the mutation rate is the mutation rate to one of the - # alleles. The overall mutation rate is then - # (n_alleles - 1) * mutation_rate. - p_e = 1 - (n_alleles - 1) * mu - else: - p_e = mu - mu * (n_alleles == 1) - # Added boolean in case we're at an invariant site - else: - # No scaling based on the number of alleles - # - so the mutation rate is the mutation rate to anything. - # This means that we must rescale the mutation rate to a different - # allele, by the number of alleles. - if n_alleles == 1: # In case we're at an invariant site - if is_match: - p_e = 1 - else: - p_e = 0 - else: - if is_match: - p_e = 1 - mu - else: - p_e = mu / (n_alleles - 1) - - p_t = p_last * (1 - rho) + rho / n - return p_t * p_e - - -class BackwardAlgorithm(LsHmmAlgorithm): - """Runs the Li and Stephens backward algorithm.""" - - def __init__( - self, - ts, - rho, - mu, - alleles, - n_alleles, - normalisation_factor, - scale_mutation=False, - precision=10, - ): - super().__init__( - ts, - rho, - mu, - alleles, - n_alleles, - precision=precision, - scale_mutation=scale_mutation, - ) - self.output = BackwardMatrix(ts, normalisation_factor) - - def compute_normalisation_factor(self): - s = 0 - for j, st in enumerate(self.T): - assert st.tree_node != tskit.NULL - assert self.N[j] > 0 - s += self.N[j] * st.value - return s - - def compute_next_probability( - self, site_id, p_next, is_match, node - ): # Note node only used in Viterbi - mu = self.mu[site_id] - n_alleles = self.n_alleles[site_id] - - if self.scale_mutation_based_on_n_alleles: - if is_match: - p_e = 1 - (n_alleles - 1) * mu - else: - p_e = mu - mu * (n_alleles == 1) - else: - if n_alleles == 1: - if is_match: - p_e = 1 - else: - p_e = 0 - else: - if is_match: - p_e = 1 - mu - else: - p_e = mu / (n_alleles - 1) - return p_next * p_e - - -class ViterbiAlgorithm(LsHmmAlgorithm): - """ - Runs the Li and Stephens Viterbi algorithm. - """ - - def __init__( - self, ts, rho, mu, alleles, n_alleles, scale_mutation=False, precision=10 - ): - super().__init__( - ts, - rho, - mu, - alleles, - n_alleles, - precision=precision, - scale_mutation=scale_mutation, - ) - self.output = ViterbiMatrix(ts) - - def compute_normalisation_factor(self): - max_st = ValueTransition(value=-1) - for st in self.T: - assert st.tree_node != tskit.NULL - if st.value > max_st.value: - max_st = st - if max_st.value == 0: - raise ValueError( - "Trying to match non-existent allele with zero mutation rate" - ) - return max_st.value - - def compute_next_probability(self, site_id, p_last, is_match, node): - rho = self.rho[site_id] - mu = self.mu[site_id] - n = self.ts.num_samples - n_alleles = self.n_alleles[site_id] - - p_no_recomb = p_last * (1 - rho + rho / n) - p_recomb = rho / n - recombination_required = False - if p_no_recomb > p_recomb: - p_t = p_no_recomb - else: - p_t = p_recomb - recombination_required = True - self.output.add_recombination_required(site_id, node, recombination_required) - - if self.scale_mutation_based_on_n_alleles: - if is_match: - # Scale mutation based on the number of alleles - # - so the mutation rate is the mutation rate to one of the - # alleles. The overall mutation rate is then - # (n_alleles - 1) * mutation_rate. - p_e = 1 - (n_alleles - 1) * mu - else: - p_e = mu - mu * (n_alleles == 1) - # Added boolean in case we're at an invariant site - else: - # No scaling based on the number of alleles - # - so the mutation rate is the mutation rate to anything. - # This means that we must rescale the mutation rate to a different - # allele, by the number of alleles. - if n_alleles == 1: # In case we're at an invariant site - if is_match: - p_e = 1 - else: - p_e = 0 - else: - if is_match: - p_e = 1 - mu - else: - p_e = mu / (n_alleles - 1) - - return p_t * p_e - - -def ls_forward_tree( - h, ts, rho, mu, precision=30, alleles=None, scale_mutation_based_on_n_alleles=False -): +def get_site_alleles(ts, h, alleles): if alleles is None: n_alleles = np.int8( [ @@ -746,8 +696,13 @@ def ls_forward_tree( alleles = [alleles for _ in range(ts.num_sites)] else: alleles, n_alleles = check_alleles(alleles, ts.num_sites) + return alleles, n_alleles + - """Forward matrix computation based on a tree sequence.""" +def ls_forward_tree( + h, ts, rho, mu, precision=30, alleles=None, scale_mutation_based_on_n_alleles=False +): + alleles, n_alleles = get_site_alleles(ts, h, alleles) fa = ForwardAlgorithm( ts, rho, @@ -757,70 +712,26 @@ def ls_forward_tree( precision=precision, scale_mutation=scale_mutation_based_on_n_alleles, ) - return fa.run_forward(h) + return fa.run(h) -def ls_backward_tree( - h, ts_mirror, rho, mu, normalisation_factor, precision=30, alleles=None -): - if alleles is None: - n_alleles = np.int8( - [ - len(np.unique(np.append(ts_mirror.genotype_matrix()[j, :], h[j]))) - for j in range(ts_mirror.num_sites) - ] - ) - alleles = tskit.ALLELES_ACGT - if len(set(alleles).intersection(next(ts_mirror.variants()).alleles)) == 0: - alleles = tskit.ALLELES_01 - if len(set(alleles).intersection(next(ts_mirror.variants()).alleles)) == 0: - raise ValueError( - """Alleles list could not be identified. - Please pass a list of lists of alleles of length m, - or a list of alleles (e.g. tskit.ALLELES_ACGT)""" - ) - alleles = [alleles for _ in range(ts_mirror.num_sites)] - else: - alleles, n_alleles = check_alleles(alleles, ts_mirror.num_sites) - - """Backward matrix computation based on a tree sequence.""" +def ls_backward_tree(h, ts, rho, mu, normalisation_factor, precision=30, alleles=None): + alleles, n_alleles = get_site_alleles(ts, h, alleles) ba = BackwardAlgorithm( - ts_mirror, + ts, rho, mu, alleles, n_alleles, - normalisation_factor, precision=precision, ) - return ba.run_backward(h) + return ba.run(h, normalisation_factor) def ls_viterbi_tree( h, ts, rho, mu, precision=30, alleles=None, scale_mutation_based_on_n_alleles=False ): - if alleles is None: - n_alleles = np.int8( - [ - len(np.unique(np.append(ts.genotype_matrix()[j, :], h[j]))) - for j in range(ts.num_sites) - ] - ) - alleles = tskit.ALLELES_ACGT - if len(set(alleles).intersection(next(ts.variants()).alleles)) == 0: - alleles = tskit.ALLELES_01 - if len(set(alleles).intersection(next(ts.variants()).alleles)) == 0: - raise ValueError( - """Alleles list could not be identified. - Please pass a list of lists of alleles of length m, - or a list of alleles (e.g. tskit.ALLELES_ACGT)""" - ) - alleles = [alleles for _ in range(ts.num_sites)] - else: - alleles, n_alleles = check_alleles(alleles, ts.num_sites) - """ - Viterbi path computation based on a tree sequence. - """ + alleles, n_alleles = get_site_alleles(ts, h, alleles) va = ViterbiAlgorithm( ts, rho, @@ -830,14 +741,13 @@ def ls_viterbi_tree( precision=precision, scale_mutation=scale_mutation_based_on_n_alleles, ) - return va.run_forward(h) + return va.run(h) class LSBase: """Superclass of Li and Stephens tests.""" def example_haplotypes(self, ts): - H = ts.genotype_matrix() s = H[:, 0].reshape(1, H.shape[0]) H = H[:, 1:] @@ -874,13 +784,17 @@ def example_parameters_haplotypes(self, ts, seed=42): for s in haplotypes: yield n, H, s, r, mu - # Mixture of random and extremes - rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)] - mus = [np.zeros(m) + 0.33, np.zeros(m) + 1e-6, np.random.rand(m) * 0.33] + # FIXME removing these as tests are abominably slow. + # We'll be refactoring all this to use pytest anyway, so let's not + # worry too much about coverage for now. + # # Mixture of random and extremes + # rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)] + # mus = [np.zeros(m) + 0.33, np.zeros(m) + 1e-6, np.random.rand(m) * 0.33] - for s, r, mu in itertools.product(haplotypes, rs, mus): - r[0] = 0 - yield n, H, s, r, mu + # import itertools + # for s, r, mu in itertools.product(haplotypes, rs, mus): + # r[0] = 0 + # yield n, H, s, r, mu def assertAllClose(self, A, B): """Assert that all entries of two matrices are 'close'""" @@ -1029,13 +943,18 @@ class TestForwardHapTree(FBAlgorithmBase): def verify(self, ts): for n, H, s, r, mu in self.example_parameters_haplotypes(ts): for scale_mutation in [False, True]: - F, c, ll = ls.forwards( - H, - s, - r, - mutation_rate=mu, - scale_mutation_based_on_n_alleles=scale_mutation, - ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # Warning from lshmm: + # Passed a vector of mutation rates, but rescaling each mutation + # rate conditional on the number of alleles + F, c, ll = ls.forwards( + H, + s, + r, + mutation_rate=mu, + scale_mutation_based_on_n_alleles=scale_mutation, + ) # Note, need to remove the first sample from the ts, and ensure # that invariant sites aren't removed. ts_check = ts.simplify(range(1, n + 1), filter_sites=False) @@ -1075,16 +994,15 @@ def verify(self, ts): c_f = ls_forward_tree(s[0, :], ts_check, r, mu) ll_tree = np.sum(np.log10(c_f.normalisation_factor)) - ts_check_mirror = mirror_coordinates(ts_check) - r_flip = np.flip(r) c_b = ls_backward_tree( - np.flip(s[0, :]), - ts_check_mirror, - r_flip, - np.flip(mu), - np.flip(c_f.normalisation_factor), + s[0, :], + ts_check, + r, + mu, + c_f.normalisation_factor, ) - B_tree = np.flip(c_b.decode(), axis=0) + B_tree = c_b.decode() + F_tree = c_f.decode() self.assertAllClose(B, B_tree) @@ -1118,3 +1036,247 @@ def verify(self, ts): scale_mutation_based_on_n_alleles=False, ) self.assertAllClose(ll, ll_check) + + +# TODO add params to run the various checks +def check_viterbi(ts, h, recombination=None, mutation=None): + h = np.array(h).astype(np.int8) + m = ts.num_sites + assert len(h) == m + if recombination is None: + recombination = np.zeros(ts.num_sites) + 1e-9 + if mutation is None: + mutation = np.zeros(ts.num_sites) + precision = 22 + + G = ts.genotype_matrix() + + path, ll = ls.viterbi( + G, + h.reshape(1, m), + recombination, + mutation_rate=mutation, + scale_mutation_based_on_n_alleles=False, + ) + assert np.isscalar(ll) + + cm = ls_viterbi_tree(h, ts, rho=recombination, mu=mutation) + ll_tree = np.sum(np.log10(cm.normalisation_factor)) + assert np.isscalar(ll_tree) + nt.assert_allclose(ll_tree, ll) + + # Check that the likelihood of the preferred path is + # the same as ll_tree (and ll). + path_tree = cm.traceback() + ll_check = ls.path_ll( + G, + h.reshape(1, m), + path_tree, + recombination, + mutation_rate=mutation, + scale_mutation_based_on_n_alleles=False, + ) + nt.assert_allclose(ll_check, ll) + + ll_ts = ts._ll_tree_sequence + ls_hmm = _tskit.LsHmm(ll_ts, recombination, mutation, precision=precision) + cm_lib = _tskit.ViterbiMatrix(ll_ts) + ls_hmm.viterbi_matrix(h, cm_lib) + path_lib = cm_lib.traceback() + + # Not true in general, but let's see how far it goes + nt.assert_array_equal(path_lib, path_tree) + + nt.assert_allclose(cm_lib.normalisation_factor, cm.normalisation_factor) + + return path + + +# TODO add params to run the various checks +def check_forward_matrix(ts, h, recombination=None, mutation=None): + precision = 22 + h = np.array(h).astype(np.int8) + n = ts.num_samples + m = ts.num_sites + assert len(h) == m + if recombination is None: + recombination = np.zeros(ts.num_sites) + 1e-9 + if mutation is None: + mutation = np.zeros(ts.num_sites) + + G = ts.genotype_matrix() + F, c, ll = ls.forwards( + G, + h.reshape(1, m), + recombination, + mutation_rate=mutation, + scale_mutation_based_on_n_alleles=False, + ) + assert F.shape == (m, n) + assert c.shape == (m,) + assert np.isscalar(ll) + + cm = ls_forward_tree( + h, ts, recombination, mutation, scale_mutation_based_on_n_alleles=False + ) + F2 = cm.decode() + nt.assert_allclose(F, F2) + nt.assert_allclose(c, cm.normalisation_factor) + ll_tree = np.sum(np.log10(cm.normalisation_factor)) + nt.assert_allclose(ll_tree, ll) + + ll_ts = ts._ll_tree_sequence + ls_hmm = _tskit.LsHmm(ll_ts, recombination, mutation, precision=precision) + cm_lib = _tskit.CompressedMatrix(ll_ts) + ls_hmm.forward_matrix(h, cm_lib) + F3 = cm_lib.decode() + + assert_compressed_matrices_equal(cm, cm_lib) + + nt.assert_allclose(F, F3) + nt.assert_allclose(c, cm_lib.normalisation_factor) + return cm_lib + + +def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None): + precision = 22 + h = np.array(h).astype(np.int8) + m = ts.num_sites + assert len(h) == m + if recombination is None: + recombination = np.zeros(ts.num_sites) + 1e-9 + if mutation is None: + mutation = np.zeros(ts.num_sites) + + G = ts.genotype_matrix() + B = ls.backwards( + G, + h.reshape(1, m), + forward_cm.normalisation_factor, + recombination, + mutation_rate=mutation, + scale_mutation_based_on_n_alleles=False, + ) + + backward_cm = ls_backward_tree( + h, + ts, + recombination, + mutation, + forward_cm.normalisation_factor, + precision=precision, + ) + nt.assert_array_equal( + backward_cm.normalisation_factor, forward_cm.normalisation_factor + ) + + ll_ts = ts._ll_tree_sequence + ls_hmm = _tskit.LsHmm(ll_ts, recombination, mutation, precision=precision) + cm_lib = _tskit.CompressedMatrix(ll_ts) + ls_hmm.backward_matrix(h, forward_cm.normalisation_factor, cm_lib) + + assert_compressed_matrices_equal(backward_cm, cm_lib) + + B_lib = cm_lib.decode() + B_tree = backward_cm.decode() + nt.assert_allclose(B_tree, B_lib) + nt.assert_allclose(B, B_lib) + + +def add_unique_sample_mutations(ts, start=0): + """ + Adds a mutation for each of the samples at equally spaced locations + along the genome. + """ + tables = ts.dump_tables() + L = int(ts.sequence_length) + assert L % ts.num_samples == 0 + gap = L // ts.num_samples + x = start + for u in ts.samples(): + site = tables.sites.add_row(position=x, ancestral_state="0") + tables.mutations.add_row(site=site, derived_state="1", node=u) + x += gap + return tables.tree_sequence() + + +class TestSingleBalancedTreeExample: + # 3.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 2.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 1.00┊ 0 1 2 3 ┊ + # 0 8 + + @staticmethod + def ts(): + return add_unique_sample_mutations( + tskit.Tree.generate_balanced(4, span=8).tree_sequence, + start=1, + ) + + @pytest.mark.parametrize("j", [0, 1, 2, 3]) + def test_match_sample(self, j): + ts = self.ts() + h = np.zeros(4) + h[j] = 1 + path = check_viterbi(ts, h) + nt.assert_array_equal([j, j, j, j], path) + cm = check_forward_matrix(ts, h) + check_backward_matrix(ts, h, cm) + + @pytest.mark.parametrize("j", [1, 2]) + def test_match_sample_missing_flanks(self, j): + ts = self.ts() + h = np.zeros(4) + h[0] = -1 + h[-1] = -1 + h[j] = 1 + path = check_viterbi(ts, h) + nt.assert_array_equal([j, j, j, j], path) + cm = check_forward_matrix(ts, h) + check_backward_matrix(ts, h, cm) + + def test_switch_each_sample(self): + ts = self.ts() + h = np.ones(4) + path = check_viterbi(ts, h) + nt.assert_array_equal([0, 1, 2, 3], path) + cm = check_forward_matrix(ts, h) + check_backward_matrix(ts, h, cm) + + def test_switch_each_sample_missing_flanks(self): + ts = self.ts() + h = np.ones(4) + h[0] = -1 + h[-1] = -1 + path = check_viterbi(ts, h) + nt.assert_array_equal([1, 1, 2, 2], path) + cm = check_forward_matrix(ts, h) + check_backward_matrix(ts, h, cm) + + def test_switch_each_sample_missing_middle(self): + ts = self.ts() + h = np.ones(4) + h[1:3] = -1 + path = check_viterbi(ts, h) + # Implementation of Viterbi switches at right-most position + nt.assert_array_equal([0, 3, 3, 3], path) + cm = check_forward_matrix(ts, h) + check_backward_matrix(ts, h, cm) + + +class TestSimulationExamples: + @pytest.mark.parametrize("n", [3, 10, 50]) + @pytest.mark.parametrize("L", [1, 10, 100]) + def test_continuous_genome(self, n, L): + ts = msprime.simulate( + n, length=L, recombination_rate=1, mutation_rate=1, random_seed=42 + ) + h = np.zeros(ts.num_sites, dtype=np.int8) + # NOTE this is a bit slow at the moment but we can disable the Python + # implementation once testing has been improved on smaller examples. + # Add ``compare_py=False``to these calls. + check_viterbi(ts, h) + cm = check_forward_matrix(ts, h) + check_backward_matrix(ts, h, cm) diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index c33f159deb..5994a7a851 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -2834,34 +2834,59 @@ def test_haplotype_input(self): m = ts.get_num_sites() fm = _tskit.CompressedMatrix(ts) vm = _tskit.ViterbiMatrix(ts) + norm = np.ones(m) ls_hmm = _tskit.LsHmm(ts, np.zeros(m), np.zeros(m)) for bad_size in [0, m - 1, m + 1, m + 2]: bad_array = np.zeros(bad_size, dtype=np.int8) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="haplotype array"): ls_hmm.forward_matrix(bad_array, fm) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="haplotype array"): + ls_hmm.backward_matrix(bad_array, norm, fm) + with pytest.raises(ValueError, match="haplotype array"): ls_hmm.viterbi_matrix(bad_array, vm) for bad_array in [[0.002], [[], []], None]: with pytest.raises(ValueError): ls_hmm.forward_matrix(bad_array, fm) with pytest.raises(ValueError): ls_hmm.viterbi_matrix(bad_array, vm) + with pytest.raises(ValueError): + ls_hmm.backward_matrix(bad_array, norm, fm) + + def test_norm_input(self): + ts = self.get_example_tree_sequence() + m = ts.get_num_sites() + cm = _tskit.CompressedMatrix(ts) + h = np.zeros(m, dtype=np.int32) + ls_hmm = _tskit.LsHmm(ts, np.zeros(m), np.zeros(m)) + for bad_size in [0, m - 1, m + 1, m + 2]: + bad_array = np.zeros(bad_size) + with pytest.raises(ValueError, match="forward_norm array"): + ls_hmm.backward_matrix(h, bad_array, cm) + + for bad_array in [[0.002], [[], []], None]: + with pytest.raises(ValueError): + ls_hmm.backward_matrix(h, bad_array, cm) def test_output_type_errors(self): ts = self.get_example_tree_sequence() m = ts.get_num_sites() h = np.zeros(m, dtype=np.int8) + norm = np.ones(m) ls_hmm = _tskit.LsHmm(ts, np.zeros(m), np.zeros(m)) for bad_type in [ls_hmm, None, m, []]: with pytest.raises(TypeError): ls_hmm.forward_matrix(h, bad_type) with pytest.raises(TypeError): ls_hmm.viterbi_matrix(h, bad_type) + with pytest.raises(TypeError): + ls_hmm.backward_matrix(h, norm, bad_type) other_ts = self.get_example_tree_sequence() output = _tskit.CompressedMatrix(other_ts) with pytest.raises(_tskit.LibraryError): ls_hmm.forward_matrix(h, output) + with pytest.raises(_tskit.LibraryError): + ls_hmm.backward_matrix(h, norm, output) output = _tskit.ViterbiMatrix(other_ts) with pytest.raises(_tskit.LibraryError): ls_hmm.viterbi_matrix(h, output) @@ -2910,7 +2935,7 @@ def verify_compressed_matrix(self, ts, output): assert len(item) == 2 node, value = item assert 0 <= node < ts.get_num_nodes() - assert 0 <= value <= 1 + assert value >= 0 for site in [m, m + 1, 2 * m]: with pytest.raises(ValueError): output.get_site(site) @@ -2924,6 +2949,17 @@ def test_forward_matrix(self): assert rv is None self.verify_compressed_matrix(ts, output) + def test_backward_matrix(self): + ts = self.get_example_tree_sequence() + m = ts.get_num_sites() + fm = _tskit.CompressedMatrix(ts) + bm = _tskit.CompressedMatrix(ts) + h = np.zeros(m, dtype=np.int32) + ls_hmm = _tskit.LsHmm(ts, np.zeros(m) + 0.1, np.zeros(m) + 0.1) + ls_hmm.forward_matrix(h, fm) + ls_hmm.backward_matrix(h, fm.normalisation_factor, bm) + self.verify_compressed_matrix(ts, bm) + def test_viterbi_matrix(self): ts = self.get_example_tree_sequence() m = ts.get_num_sites()