From d6216ff9c40cb74511f7ffb39a74b2fbac87d583 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 7 Jul 2023 17:03:44 +0100 Subject: [PATCH] Update --- python/tests/test_tree_iters.py | 53 ++++++++++++++++++++++++++++++++- 1 file changed, 52 insertions(+), 1 deletion(-) diff --git a/python/tests/test_tree_iters.py b/python/tests/test_tree_iters.py index 41e9f3c2b6..cb0fa6145d 100644 --- a/python/tests/test_tree_iters.py +++ b/python/tests/test_tree_iters.py @@ -23,6 +23,7 @@ Tests for tree iterator schemes. Mostly used to develop the incremental iterator infrastructure. """ +import numpy as np import pytest import tskit @@ -37,6 +38,8 @@ REVERSE = -1 +# TODO deal with direction change and calling next()/prev on the null +# tree class TreePosition: def __init__(self, ts): self.direction = FORWARD @@ -125,6 +128,55 @@ def prev(self): return j >= 0 or right > 0 +class StatefulTree: + """ + Just enough functionality to mimic the low-level tree implementation + for testing of forward/backward moving. + """ + + def __init__(self, ts): + self.ts = ts + self.tree_pos = TreePosition(ts) + self.index = -1 + self.parent = np.zeros(ts.num_nodes + 1, dtype=int) - 1 + + def next(self): # NOQA: A003 + valid = self.tree_pos.next() + if valid: + for j in range(*self.tree_pos.out_range): + e = self.tree_pos.out_order[j] + c = self.ts.edges_child[e] + self.parent[c] = -1 + for j in range(*self.tree_pos.in_range): + e = self.tree_pos.in_order[j] + c = self.ts.edges_child[e] + p = self.ts.edges_parent[e] + self.parent[c] = p + self.index += 1 + else: + self.index = -1 + return valid + + def prev(self): + if self.index == -1: + self.index = self.ts.num_trees + valid = self.tree_pos.prev() + if valid: + for j in range(*self.tree_pos.out_range, -1): + e = self.tree_pos.out_order[j] + c = self.ts.edges_child[e] + self.parent[c] = -1 + for j in range(*self.tree_pos.in_range, -1): + e = self.tree_pos.in_order[j] + c = self.ts.edges_child[e] + p = self.ts.edges_parent[e] + self.parent[c] = p + self.index -= 1 + else: + self.index = -1 + return valid + + def check_iters_forward(ts): alg_t_output = tsutil.algorithm_T(ts) @@ -191,7 +243,6 @@ def test_forward_full(self, n): @pytest.mark.parametrize("n", [2, 3, 4]) def test_back_full(self, n): ts = tsutil.all_trees_ts(n) - print(ts.draw_text()) check_iters_back(ts)