diff --git a/python/tests/test_tree_positioning.py b/python/tests/test_tree_positioning.py index 82e628d74d..bce4844c7d 100644 --- a/python/tests/test_tree_positioning.py +++ b/python/tests/test_tree_positioning.py @@ -42,6 +42,8 @@ # tree class TreePosition: def __init__(self, ts): + self.num_trees = ts.num_trees + self.index = -1 self.direction = FORWARD self.num_edges = ts.num_edges self.sequence_length = ts.sequence_length @@ -49,13 +51,28 @@ def __init__(self, ts): self.edges_right = ts.edges_right self.edge_left_order = ts.indexes_edge_insertion_order self.edge_right_order = ts.indexes_edge_removal_order + self.edge_left_current_index = 0 + self.edge_right_current_index = 0 self.interval = [0, 0] self.in_order = None self.out_order = None self.in_range = [0, 0] self.out_range = [0, 0] + def __str__(self): + s = f"index: {self.index}\ninterval: {self.interval}\n" + s += f"direction: {self.direction}\n" + s += f"in_range: {self.in_range}\n" + s += f"out_range: {self.out_range}\n" + return s + def next(self): # NOQA: A003 + if self.index == -1: + self.direction = FORWARD + + direction_change = int(self.direction != FORWARD) + self.direction = FORWARD + M = self.num_edges in_coords = self.edges_left in_order = self.edge_left_order @@ -63,17 +80,19 @@ def next(self): # NOQA: A003 out_order = self.edge_right_order x = self.interval[1] - k = self.out_range[1] + k = self.edge_right_current_index + direction_change self.out_range[0] = k while k < M and out_coords[out_order[k]] == x: k += 1 self.out_range[1] = k + self.edge_right_current_index = k - j = self.in_range[1] + j = self.edge_left_current_index + direction_change self.in_range[0] = j while j < M and in_coords[in_order[j]] == x: j += 1 self.in_range[1] = j + self.edge_left_current_index = j left = x right = self.sequence_length @@ -86,14 +105,23 @@ def next(self): # NOQA: A003 self.interval[:] = [left, right] self.out_order = out_order self.in_order = in_order - return j < M or left < self.sequence_length + self.index += 1 + assert (self.index != self.num_trees) == (j < M or left < self.sequence_length) + if self.index == self.num_trees: + self.index = -1 + return self.index != -1 def prev(self): M = self.num_edges - if self.interval[1] == 0: + if self.index == -1: + self.index = self.num_trees self.interval[0] = self.sequence_length - self.out_range[1] = M - 1 - self.in_range[1] = M - 1 + self.edge_right_current_index = M - 1 + self.edge_left_current_index = M - 1 + self.direction = REVERSE + + direction_change = int(self.direction != REVERSE) + self.direction = REVERSE in_coords = self.edges_right in_order = self.edge_right_order @@ -101,17 +129,19 @@ def prev(self): out_order = self.edge_left_order x = self.interval[0] - k = self.out_range[1] + k = self.edge_left_current_index - direction_change self.out_range[0] = k while k >= 0 and out_coords[out_order[k]] == x: k -= 1 self.out_range[1] = k + self.edge_left_current_index = k - j = self.in_range[1] + j = self.edge_right_current_index - direction_change self.in_range[0] = j while j >= 0 and in_coords[in_order[j]] == x: j -= 1 self.in_range[1] = j + self.edge_right_current_index = j right = x left = 0 @@ -125,7 +155,11 @@ def prev(self): self.interval[:] = [left, right] self.out_order = out_order self.in_order = in_order - return j >= 0 or right > 0 + + assert self.index >= 0 + self.index -= 1 + assert (self.index != -1) == (j >= 0 or right > 0) + return self.index != -1 class StatefulTree: @@ -137,8 +171,18 @@ class StatefulTree: def __init__(self, ts): self.ts = ts self.tree_pos = TreePosition(ts) - self.index = -1 - self.parent = np.zeros(ts.num_nodes + 1, dtype=int) - 1 + self.parent = [-1 for _ in range(ts.num_nodes)] + + def __str__(self): + s = f"parent: {self.parent}\nposition:\n" + for line in str(self.tree_pos).splitlines(): + s += f"\t{line}\n" + return s + + def assert_equal(self, other): + assert self.parent == other.parent + assert self.tree_pos.index == other.tree_pos.index + assert self.tree_pos.interval == other.tree_pos.interval def next(self): # NOQA: A003 valid = self.tree_pos.next() @@ -152,14 +196,17 @@ def next(self): # NOQA: A003 c = self.ts.edges_child[e] p = self.ts.edges_parent[e] self.parent[c] = p - self.index += 1 - else: - self.index = -1 return valid + def seek_forward(self, index): + while self.tree_pos.index != index: + self.next() + + def seek_backward(self, index): + while self.tree_pos.index != index: + self.prev() + def prev(self): - if self.index == -1: - self.index = self.ts.num_trees valid = self.tree_pos.prev() if valid: for j in range(*self.tree_pos.out_range, -1): @@ -171,9 +218,6 @@ def prev(self): c = self.ts.edges_child[e] p = self.ts.edges_parent[e] self.parent[c] = p - self.index -= 1 - else: - self.index = -1 return valid @@ -274,6 +318,80 @@ def check_iters_back(ts): assert i == -1 +def check_forward_back_sweep(ts): + alg_t_output = [ + (list(interval), list(parent)) for interval, parent in tsutil.algorithm_T(ts) + ] + for j in range(ts.num_trees - 1): + tree = StatefulTree(ts) + # Seek forward to j + k = 0 + while k <= j: + tree.next() + interval, parent = alg_t_output[k] + assert tree.tree_pos.index == k + assert tree.tree_pos.interval == interval + assert parent == tree.parent + k += 1 + k = j + # And back to zero + while k >= 0: + interval, parent = alg_t_output[k] + assert tree.tree_pos.index == k + assert tree.tree_pos.interval == interval + assert parent == tree.parent + tree.prev() + k -= 1 + + +class TestDirectionSwitching: + # 2.00┊ ┊ 4 ┊ 4 ┊ 4 ┊ + # ┊ ┊ ┏━┻┓ ┊ ┏┻━┓ ┊ ┏┻━┓ ┊ + # 1.00┊ 3 ┊ ┃ 3 ┊ 3 ┃ ┊ 3 ┃ ┊ + # ┊ ┏━╋━┓ ┊ ┃ ┏┻┓ ┊ ┏┻┓ ┃ ┊ ┏┻┓ ┃ ┊ + # 0.00┊ 0 1 2 ┊ 0 1 2 ┊ 0 2 1 ┊ 0 1 2 ┊ + # 0 1 2 3 4 + # index 0 1 2 3 + def ts(self): + return tsutil.all_trees_ts(3) + + @pytest.mark.parametrize("index", [1, 2, 3]) + def test_forward_to_prev(self, index): + tree1 = StatefulTree(self.ts()) + tree1.seek_forward(index) + tree1.prev() + tree2 = StatefulTree(self.ts()) + tree2.seek_forward(index - 1) + tree1.assert_equal(tree2) + tree2 = StatefulTree(self.ts()) + tree2.seek_backward(index - 1) + tree1.assert_equal(tree2) + + @pytest.mark.parametrize("index", [0, 1, 2]) + def test_backward_to_next(self, index): + tree1 = StatefulTree(self.ts()) + tree1.seek_backward(index) + tree1.next() + tree2 = StatefulTree(self.ts()) + tree2.seek_backward(index + 1) + tree1.assert_equal(tree2) + tree2 = StatefulTree(self.ts()) + tree2.seek_forward(index + 1) + tree1.assert_equal(tree2) + + @pytest.mark.parametrize("index", [1, 2, 3]) + def test_forward_next_prev(self, index): + tree1 = StatefulTree(self.ts()) + tree1.seek_forward(index) + tree1.prev() + tree2 = StatefulTree(self.ts()) + tree2.seek_forward(index - 1) + tree1.assert_equal(tree2) + tree2 = StatefulTree(self.ts()) + tree2.seek_backward(index - 1) + tree1.assert_equal(tree2) + + class TestAllTreesTs: @pytest.mark.parametrize("n", [2, 3, 4]) def test_forward_full(self, n): @@ -285,6 +403,11 @@ def test_back_full(self, n): ts = tsutil.all_trees_ts(n) check_iters_back(ts) + @pytest.mark.parametrize("n", [2, 3, 4]) + def test_forward_back(self, n): + ts = tsutil.all_trees_ts(n) + check_forward_back_sweep(ts) + class TestSuiteExamples: @pytest.mark.parametrize("ts", get_example_tree_sequences())