From cbf2dda36e3813ddd85ed031d2603e0a7847d7f0 Mon Sep 17 00:00:00 2001 From: peter Date: Sun, 17 Sep 2023 13:52:19 -0700 Subject: [PATCH] dont impute missing data --- c/tskit/trees.c | 10 ++- python/tests/test_extend_edges.py | 101 ++++++++++++++++++++++++------ 2 files changed, 89 insertions(+), 22 deletions(-) diff --git a/c/tskit/trees.c b/c/tskit/trees.c index ca634e57ff..c69ac22fd7 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -6896,11 +6896,12 @@ tsk_treeseq_extend_edges_iter( tsk_id_t *degree = tsk_calloc(num_nodes, sizeof(*degree)); tsk_id_t *out_parent = tsk_malloc(num_nodes * sizeof(*out_parent)); tsk_bool_t *keep = tsk_calloc(num_edges, sizeof(*keep)); + bool *not_sample = tsk_malloc(num_nodes * sizeof(*not_sample)); memset(&edge_list_heap, 0, sizeof(edge_list_heap)); memset(&tree_pos, 0, sizeof(tree_pos)); - if (keep == NULL || out_parent == NULL || degree == NULL) { + if (keep == NULL || out_parent == NULL || degree == NULL || not_sample == NULL) { ret = TSK_ERR_NO_MEMORY; goto out; } @@ -6919,6 +6920,10 @@ tsk_treeseq_extend_edges_iter( goto out; } + for (tj = 0; tj < (tsk_id_t) tables->nodes.num_rows; tj++) { + not_sample[tj] = ((tables->nodes.flags[tj] & TSK_NODE_IS_SAMPLE) == 0); + } + if (forwards) { near_side = edges->left; far_side = edges->right; @@ -7000,7 +7005,7 @@ tsk_treeseq_extend_edges_iter( c = edges->child[e_in]; p = out_parent[c]; p_in = edges->parent[e_in]; - while ((p != TSK_NULL) && (degree[p] == 0) && (p != p_in)) { + while ((p != TSK_NULL) && (degree[p] == 0) && (p != p_in) && not_sample[p]) { p = out_parent[p]; } if (p == p_in) { @@ -7050,6 +7055,7 @@ tsk_treeseq_extend_edges_iter( tsk_safe_free(degree); tsk_safe_free(out_parent); tsk_safe_free(keep); + tsk_safe_free(not_sample); return ret; } diff --git a/python/tests/test_extend_edges.py b/python/tests/test_extend_edges.py index 5a3e3a60ae..2a3544bdfc 100644 --- a/python/tests/test_extend_edges.py +++ b/python/tests/test_extend_edges.py @@ -100,6 +100,7 @@ def _extend(ts, forwards=True): # `out_parent` will record the sub-forest of edges-to-be-removed out_parent = np.full(ts.num_nodes, -1, dtype="int") keep = np.full(ts.num_edges, True, dtype=bool) + not_sample = [not n.is_sample() for n in ts.nodes()] edges = ts.tables.edges.copy() @@ -188,7 +189,7 @@ def _extend(ts, forwards=True): c = edges.child[e_in] p = out_parent[c] p_in = edges.parent[e_in] - while p != tskit.NULL and degree[p] == 0 and p != p_in: + while p != tskit.NULL and degree[p] == 0 and p != p_in and not_sample[p]: p = out_parent[p] if p == p_in: # we might have passed the interval that a @@ -390,7 +391,7 @@ def test_max_iter(self): eet = ets.extend_edges(max_iter=2).dump_tables() eet.assert_equals(et) - def test_simple_ex(self): + def get_simple_ex(self, samples=None): # An example where you need to go forwards *and* backwards: # 7 and 8 should be extended to the whole sequence # @@ -424,7 +425,7 @@ def test_simple_ex(self): 8: 2.0, } # (p, c, l, r) - edge_stuff = [ + edges = [ (4, 0, 0, 10), (4, 1, 0, 5), (4, 1, 7, 10), @@ -432,41 +433,97 @@ def test_simple_ex(self): (5, 2, 5, 10), (5, 3, 0, 2), (5, 3, 5, 10), + (7, 2, 2, 5), + (7, 4, 2, 5), + (8, 1, 5, 7), + (8, 5, 5, 7), + (6, 3, 2, 5), (6, 4, 0, 2), (6, 4, 5, 10), (6, 5, 0, 2), (6, 5, 7, 10), - (6, 3, 2, 5), (6, 7, 2, 5), (6, 8, 5, 7), + ] + # here is the 'right answer' (but note only with the default args) + extended_edges = [ + (4, 0, 0, 10), + (4, 1, 0, 5), + (4, 1, 7, 10), + (5, 2, 0, 2), + (5, 2, 5, 10), + (5, 3, 0, 10), (7, 2, 2, 5), - (7, 4, 2, 5), + (7, 4, 0, 10), (8, 1, 5, 7), - (8, 5, 5, 7), + (8, 5, 0, 10), + (6, 7, 0, 10), + (6, 8, 0, 10), ] tables = tskit.TableCollection(sequence_length=10) nodes = tables.nodes + if samples is None: + samples = [0, 1, 2, 3] for n, t in node_times.items(): - flags = tskit.NODE_IS_SAMPLE if n < 4 else 0 + flags = tskit.NODE_IS_SAMPLE if n in samples else 0 nodes.add_row(time=t, flags=flags) - edges = tables.edges - for p, c, l, r in edge_stuff: - edges.add_row(parent=p, child=c, left=l, right=r) - tables.sort() + for p, c, l, r in edges: + tables.edges.add_row(parent=p, child=c, left=l, right=r) ts = tables.tree_sequence() - ets = ts.extend_edges() + tables.edges.clear() + for p, c, l, r in extended_edges: + tables.edges.add_row(parent=p, child=c, left=l, right=r) + ets = tables.tree_sequence() assert ts.num_edges == 18 assert ets.num_edges == 12 - for t in ets.trees(): - assert 7 in t.nodes() - assert 8 in t.nodes() - assert t.parent(4) == 7 - assert t.parent(7) == 6 - assert t.parent(5) == 8 - assert t.parent(3) == 5 - assert t.parent(8) == 6 + return ts, ets + + def test_simple_ex(self): + ts, right_ets = self.get_simple_ex() + ets = ts.extend_edges() + ets.tables.assert_equals(right_ets.tables) self.verify_extend_edges(ts) + def test_internal_samples(self): + # Now we should have the same but not extend 5 (where * is): + # + # 6 6 6 6 + # +-+-+ +-+-+ +-+-+ +-+-+ + # 7 * 7 * 7 8 7 8 + # | | ++-+ | | +-++ | | + # 4 5 4 | * 4 | 5 4 5 + # +++ +++ +++ | | | | +++ +++ +++ + # 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 + # + # (p, c, l, r) + edges = [ + (4, 0, 0, 10), + (4, 1, 0, 5), + (4, 1, 7, 10), + (5, 2, 0, 2), + (5, 2, 5, 10), + (5, 3, 0, 2), + (5, 3, 5, 10), + (7, 2, 2, 5), + (7, 4, 0, 10), + (8, 1, 5, 7), + (8, 5, 5, 10), + (6, 3, 2, 5), + (6, 5, 0, 2), + (6, 7, 0, 10), + (6, 8, 5, 10), + ] + ts, _ = self.get_simple_ex(samples=[0, 1, 2, 3, 5]) + tables = ts.dump_tables() + tables.edges.clear() + for p, c, l, r in edges: + tables.edges.add_row(parent=p, child=c, left=l, right=r) + ets = ts.extend_edges() + ets.tables.assert_equals(tables) + # validation doesn't work with internal, incomplete samples + # (and it would be a pain to make it work) + # self.verify_extend_edges(ts) + def test_wright_fisher(self): tables = wf.wf_sim(N=5, ngens=20, num_loci=100, deep_history=False, seed=3) tables.sort() @@ -515,6 +572,10 @@ def check(self, ts): py_ts = extend_edges(ts) lib_ts = ts.extend_edges() lib_ts.tables.assert_equals(py_ts.tables) + assert np.all(ts.genotype_matrix() == lib_ts.genotype_matrix()) + sts = ts.simplify() + lib_sts = lib_ts.simplify() + lib_sts.tables.assert_equals(sts.tables, ignore_provenance=True) @pytest.mark.parametrize("ts", get_example_tree_sequences()) def test_suite_examples_defaults(self, ts):