diff --git a/python/tests/test_tree_positioning.py b/python/tests/test_tree_positioning.py index 961f0810f7..39f0b5ccab 100644 --- a/python/tests/test_tree_positioning.py +++ b/python/tests/test_tree_positioning.py @@ -98,10 +98,6 @@ def seek_forward(self, index): old_left, old_right = self.tree_pos.interval self.tree_pos.seek_forward(index) left, right = self.tree_pos.interval - # print() - # print("Current interval:", old_left, old_right) - # print("New interval:", left, right) - # print("index:", index, "out_range:", self.tree_pos.out_range) for j in range(self.tree_pos.out_range.start, self.tree_pos.out_range.stop): e = self.tree_pos.out_range.order[j] e_left = self.ts.edges_left[e] @@ -113,16 +109,9 @@ def seek_forward(self, index): assert self.parent[c] != -1 self.parent[c] = -1 assert e_left < left - # print("index:", index, "in_range:", self.tree_pos.in_range) for j in range(self.tree_pos.in_range.start, self.tree_pos.in_range.stop): e = self.tree_pos.in_range.order[j] if self.ts.edges_left[e] <= left < self.ts.edges_right[e]: - # print("keep", j, e, self.ts.edges_left[e], self.ts.edges_right[e]) - # print( - # "INSERT:", - # self.ts.edge(e), - # self.ts.nodes_time[self.ts.edges_parent[e]], - # ) c = self.ts.edges_child[e] p = self.ts.edges_parent[e] self.parent[c] = p @@ -132,12 +121,34 @@ def seek_forward(self, index): # The first and last indexes in the range should always be valid # for the tree. 
assert a < j < b - 1 - # print("skip", j, e, self.ts.edges_left[e], self.ts.edges_right[e]) def seek_backward(self, index): - # TODO - while self.tree_pos.index != index: - self.prev() + old_left, old_right = self.tree_pos.interval + self.tree_pos.seek_backward(index) + left, right = self.tree_pos.interval + for j in range(self.tree_pos.out_range.start, self.tree_pos.out_range.stop, -1): + e = self.tree_pos.out_range.order[j] + e_right = self.ts.edges_right[e] + # We only need to remove an edge if it's in the current tree, which + # can only happen if the edge's right coord is >= the current tree's + # right coordinate. + if e_right >= old_right: + c = self.ts.edges_child[e] + assert self.parent[c] != -1 + self.parent[c] = -1 + assert e_right > right + for j in range(self.tree_pos.in_range.start, self.tree_pos.in_range.stop, -1): + e = self.tree_pos.in_range.order[j] + if self.ts.edges_right[e] >= right > self.ts.edges_left[e]: + c = self.ts.edges_child[e] + p = self.ts.edges_parent[e] + self.parent[c] = p + else: + a = self.tree_pos.in_range.start + b = self.tree_pos.in_range.stop + # The first and last indexes in the range should always be valid + # for the tree. 
+ assert a > j > b + 1 def iter_backward(self, index): while self.tree_pos.index != index: @@ -267,6 +278,58 @@ def check_forward_back_sweep(ts): k -= 1 +def check_seek_forward_out_range_is_empty(ts, index): + tree = StatefulTree(ts) + tree.seek_forward(index) + assert tree.tree_pos.out_range.start == tree.tree_pos.out_range.stop + tree.iter_backward(-1) + tree.seek_forward(index) + assert tree.tree_pos.out_range.start == tree.tree_pos.out_range.stop + + +def check_seek_backward_out_range_is_empty(ts, index): + tree = StatefulTree(ts) + tree.seek_backward(index) + assert tree.tree_pos.out_range.start == tree.tree_pos.out_range.stop + tree.iter_forward(-1) + tree.seek_backward(index) + assert tree.tree_pos.out_range.start == tree.tree_pos.out_range.stop + + +def check_seek_forward_from_null(ts, index): + tree1 = StatefulTree(ts) + tree1.seek_forward(index) + tree2 = StatefulTree(ts) + tree2.iter_forward(index) + tree1.assert_equal(tree2) + + +def check_seek_backward_from_null(ts, index): + tree1 = StatefulTree(ts) + tree1.seek_backward(index) + tree2 = StatefulTree(ts) + tree2.iter_backward(index) + tree1.assert_equal(tree2) + + +def check_seek_forward_from_first(ts, index): + tree1 = StatefulTree(ts) + tree1.next() + tree1.seek_forward(index) + tree2 = StatefulTree(ts) + tree2.iter_forward(index) + tree1.assert_equal(tree2) + + +def check_seek_backward_from_last(ts, index): + tree1 = StatefulTree(ts) + tree1.prev() + tree1.seek_backward(index) + tree2 = StatefulTree(ts) + tree2.iter_backward(index) + tree1.assert_equal(tree2) + + class TestDirectionSwitching: # 2.00┊ ┊ 4 ┊ 4 ┊ 4 ┊ # ┊ ┊ ┏━┻┓ ┊ ┏┻━┓ ┊ ┏┻━┓ ┊ @@ -278,16 +340,22 @@ class TestDirectionSwitching: def ts(self): return tsutil.all_trees_ts(3) + @pytest.mark.parametrize("index", [0, 1, 2, 3]) + def test_iter_backward_matches_iter_forward(self, index): + ts = self.ts() + tree1 = StatefulTree(ts) + tree1.iter_forward(index) + tree2 = StatefulTree(ts) + tree2.iter_backward(index) + tree1.assert_equal(tree2) + 
@pytest.mark.parametrize("index", [1, 2, 3]) - def test_forward_to_prev(self, index): + def test_prev_from_seek_forward(self, index): tree1 = StatefulTree(self.ts()) - tree1.iter_forward(index) + tree1.seek_forward(index) tree1.prev() tree2 = StatefulTree(self.ts()) - tree2.iter_forward(index - 1) - tree1.assert_equal(tree2) - tree2 = StatefulTree(self.ts()) - tree2.iter_backward(index - 1) + tree2.seek_forward(index - 1) tree1.assert_equal(tree2) @pytest.mark.parametrize("index", [1, 2, 3]) @@ -300,57 +368,90 @@ def test_seek_forward_from_prev(self, index): tree2.iter_forward(index) tree1.assert_equal(tree2) - @pytest.mark.parametrize("index", [0, 1, 2]) - def test_backward_to_next(self, index): + @pytest.mark.parametrize("index", [0, 1, 2, 3]) + def test_seek_forward_from_null(self, index): + ts = self.ts() + check_seek_forward_from_null(ts, index) + + def test_seek_forward_next_null(self): tree1 = StatefulTree(self.ts()) - tree1.iter_backward(index) + tree1.seek_forward(3) tree1.next() - tree2 = StatefulTree(self.ts()) - tree2.iter_backward(index + 1) - tree1.assert_equal(tree2) - tree2 = StatefulTree(self.ts()) - tree2.iter_forward(index + 1) - tree1.assert_equal(tree2) + assert tree1.tree_pos.index == -1 + assert list(tree1.tree_pos.interval) == [0, 0] - @pytest.mark.parametrize("index", [1, 2, 3]) - def test_forward_next_prev(self, index): + @pytest.mark.parametrize("index", [0, 1, 2]) + def test_next_from_seek_backward(self, index): tree1 = StatefulTree(self.ts()) - tree1.iter_forward(index) - tree1.prev() - tree2 = StatefulTree(self.ts()) - tree2.iter_forward(index - 1) - tree1.assert_equal(tree2) + tree1.seek_backward(index) + tree1.next() tree2 = StatefulTree(self.ts()) - tree2.iter_backward(index - 1) + tree2.seek_backward(index + 1) tree1.assert_equal(tree2) - @pytest.mark.parametrize("index", [1, 2, 3]) - def test_seek_forward_next_prev(self, index): + @pytest.mark.parametrize("index", [0, 1, 2]) + def test_seek_backward_from_next(self, index): tree1 = 
StatefulTree(self.ts()) - tree1.iter_forward(index) - tree1.prev() - tree2 = StatefulTree(self.ts()) - tree2.seek_forward(index - 1) - tree1.assert_equal(tree2) + tree1.iter_backward(index) + tree1.next() + tree1.seek_backward(index) tree2 = StatefulTree(self.ts()) - tree2.iter_backward(index - 1) + tree2.iter_backward(index) tree1.assert_equal(tree2) - @pytest.mark.parametrize("index", [1, 2, 3]) - def test_seek_forward_from_null(self, index): - tree1 = StatefulTree(self.ts()) - tree1.seek_forward(index) - tree2 = StatefulTree(self.ts()) - tree2.iter_forward(index) - tree1.assert_equal(tree2) + @pytest.mark.parametrize("index", [0, 1, 2, 3]) + def test_seek_backward_from_null(self, index): + ts = self.ts() + check_seek_backward_from_null(ts, index) - def test_seek_forward_next_null(self): + def test_seek_backward_prev_null(self): tree1 = StatefulTree(self.ts()) - tree1.seek_forward(3) - tree1.next() + tree1.seek_backward(0) + tree1.prev() assert tree1.tree_pos.index == -1 assert list(tree1.tree_pos.interval) == [0, 0] + @pytest.mark.parametrize("index", [0, 1, 2, 3]) + def test_seek_forward_out_range_is_empty(self, index): + ts = self.ts() + check_seek_forward_out_range_is_empty(ts, index) + + @pytest.mark.parametrize("index", [0, 1, 2, 3]) + def test_seek_backward_out_range_is_empty(self, index): + ts = self.ts() + check_seek_backward_out_range_is_empty(ts, index) + + +class TestTreePositionStep: + def ts(self): + return tsutil.all_trees_ts(3) + + @pytest.mark.parametrize("index", [0, 1, 2]) + def test_tree_position_step_forward(self, index): + ts = self.ts() + tree1_pos = tsutil.TreePosition(ts) + tree1_pos.seek_forward(index) + tree1_pos.step(direction=1) + tree2_pos = tsutil.TreePosition(ts) + tree2_pos.seek_forward(index + 1) + tree1_pos.assert_equal(tree2_pos) + + @pytest.mark.parametrize("index", [1, 2, 3]) + def test_tree_position_step_backward(self, index): + ts = self.ts() + tree1_pos = tsutil.TreePosition(ts) + tree1_pos.seek_backward(index) + 
tree1_pos.step(direction=-1) + tree2_pos = tsutil.TreePosition(ts) + tree2_pos.seek_backward(index - 1) + tree1_pos.assert_equal(tree2_pos) + + def test_tree_position_step_invalid_direction(self): + ts = self.ts() + # Test for unallowed direction + with pytest.raises(ValueError, match="Direction must be FORWARD"): + tsutil.TreePosition(ts).step(direction="foo") + class TestSeeking: @tests.cached_example @@ -361,20 +462,13 @@ def ts(self): @pytest.mark.parametrize("index", range(26)) def test_seek_forward_from_null(self, index): - tree1 = StatefulTree(self.ts()) - tree1.seek_forward(index) - tree2 = StatefulTree(self.ts()) - tree2.iter_forward(index) - tree1.assert_equal(tree2) + ts = self.ts() + check_seek_forward_from_null(ts, index) @pytest.mark.parametrize("index", range(1, 26)) def test_seek_forward_from_first(self, index): - tree1 = StatefulTree(self.ts()) - tree1.next() - tree1.seek_forward(index) - tree2 = StatefulTree(self.ts()) - tree2.iter_forward(index) - tree1.assert_equal(tree2) + ts = self.ts() + check_seek_forward_from_first(ts, index) @pytest.mark.parametrize("index", range(1, 26)) def test_seek_last_from_index(self, index): @@ -386,6 +480,36 @@ def test_seek_last_from_index(self, index): tree2.prev() tree1.assert_equal(tree2) + @pytest.mark.parametrize("index", range(26)) + def test_seek_backward_from_null(self, index): + ts = self.ts() + check_seek_backward_from_null(ts, index) + + @pytest.mark.parametrize("index", range(0, 25)) + def test_seek_backward_from_last(self, index): + ts = self.ts() + check_seek_backward_from_last(ts, index) + + @pytest.mark.parametrize("index", range(0, 25)) + def test_seek_first_from_index(self, index): + ts = self.ts() + tree1 = StatefulTree(ts) + tree1.iter_backward(index) + tree1.seek_backward(0) + tree2 = StatefulTree(ts) + tree2.next() + tree1.assert_equal(tree2) + + @pytest.mark.parametrize("index", range(26)) + def test_seek_forward_out_range_is_empty(self, index): + ts = self.ts() + 
check_seek_forward_out_range_is_empty(ts, index) + + @pytest.mark.parametrize("index", range(26)) + def test_seek_backward_out_range_is_empty(self, index): + ts = self.ts() + check_seek_backward_out_range_is_empty(ts, index) + class TestAllTreesTs: @pytest.mark.parametrize("n", [2, 3, 4]) @@ -416,11 +540,7 @@ def ts(self): @pytest.mark.parametrize("index", [1, 5, 10, 50, 100]) def test_seek_forward_from_null(self, index): ts = self.ts() - tree1 = StatefulTree(ts) - tree1.seek_forward(index) - tree2 = StatefulTree(ts) - tree2.iter_forward(index) - tree1.assert_equal(tree2) + check_seek_forward_from_null(ts, index) @pytest.mark.parametrize("num_trees", [1, 5, 10, 50, 100]) def test_seek_forward_from_mid(self, num_trees): @@ -434,6 +554,33 @@ def test_seek_forward_from_mid(self, num_trees): tree2.iter_forward(dest_index) tree1.assert_equal(tree2) + @pytest.mark.parametrize("index", [1, 5, 10, 50, 100]) + def test_seek_backward_from_null(self, index): + ts = self.ts() + check_seek_backward_from_null(ts, index) + + @pytest.mark.parametrize("num_trees", [1, 5, 10, 50, 100]) + def test_seek_backward_from_mid(self, num_trees): + ts = self.ts() + start_index = ts.num_trees // 2 + dest_index = max(start_index - num_trees, 0) + tree1 = StatefulTree(ts) + tree1.iter_backward(start_index) + tree1.seek_backward(dest_index) + tree2 = StatefulTree(ts) + tree2.iter_backward(dest_index) + tree1.assert_equal(tree2) + + @pytest.mark.parametrize("index", [1, 5, 10, 50, 100]) + def test_seek_forward_out_range_is_empty(self, index): + ts = self.ts() + check_seek_forward_out_range_is_empty(ts, index) + + @pytest.mark.parametrize("index", [1, 5, 10, 50, 100]) + def test_seek_backward_out_range_is_empty(self, index): + ts = self.ts() + check_seek_backward_out_range_is_empty(ts, index) + def test_forward_full(self): check_iters_forward(self.ts()) @@ -453,18 +599,29 @@ def test_back_full(self, ts): @pytest.mark.parametrize("ts", get_example_tree_sequences()) def test_seek_forward_from_null(self, ts): index = 
ts.num_trees // 2 - tree1 = StatefulTree(ts) - tree1.seek_forward(index) - tree2 = StatefulTree(ts) - tree2.iter_forward(index) - tree1.assert_equal(tree2) + check_seek_forward_from_null(ts, index) @pytest.mark.parametrize("ts", get_example_tree_sequences()) def test_seek_forward_from_first(self, ts): index = ts.num_trees - 1 - tree1 = StatefulTree(ts) - tree1.next() - tree1.seek_forward(index) - tree2 = StatefulTree(ts) - tree2.iter_forward(index) - tree1.assert_equal(tree2) + check_seek_forward_from_first(ts, index) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_seek_backward_from_null(self, ts): + index = ts.num_trees // 2 + check_seek_backward_from_null(ts, index) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_seek_backward_from_last(self, ts): + index = 0 + check_seek_backward_from_last(ts, index) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_seek_forward_out_range_is_empty(self, ts): + index = ts.num_trees // 2 + check_seek_forward_out_range_is_empty(ts, index) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_seek_backward_out_range_is_empty(self, ts): + index = ts.num_trees // 2 + check_seek_backward_out_range_is_empty(ts, index) diff --git a/python/tests/tsutil.py b/python/tests/tsutil.py index b86a159274..a5c5347d8a 100644 --- a/python/tests/tsutil.py +++ b/python/tests/tsutil.py @@ -1752,6 +1752,11 @@ def __str__(self): s += f"out_range: {self.out_range}\n" return s + def assert_equal(self, other): + assert self.index == other.index + assert self.direction == other.direction + assert self.interval == other.interval + def set_null(self): self.index = -1 self.interval.left = 0 @@ -1887,6 +1892,10 @@ def seek_forward(self, index): j += 1 self.out_range.stop = j + if self.index == -1: + # No edges, so out_range should be empty + self.out_range.start = self.out_range.stop + # The range of edges we need to consider for the new tree # must have 
right coordinate > left j = left_current_index @@ -1904,6 +1913,76 @@ def seek_forward(self, index): self.in_range.order = left_order self.index = index + def seek_backward(self, index): + # NOTE this is still in development and not fully tested. + assert index >= 0 + M = self.ts.num_edges + breakpoints = self.ts.breakpoints(as_array=True) + left_coords = self.ts.edges_left + left_order = self.ts.indexes_edge_insertion_order + right_coords = self.ts.edges_right + right_order = self.ts.indexes_edge_removal_order + + if self.index == -1: + assert index < self.ts.num_trees + self.index = self.ts.num_trees + self.interval.left = self.ts.sequence_length + self.in_range.stop = M - 1 + self.out_range.stop = M - 1 + self.direction = REVERSE + else: + assert index <= self.index + + if self.direction == REVERSE: + left_current_index = self.out_range.stop + right_current_index = self.in_range.stop + else: + left_current_index = self.in_range.stop - 1 + right_current_index = self.out_range.stop - 1 + + self.direction = REVERSE + right = breakpoints[index + 1] + + # The range of edges we need consider for removal starts + # at the current left index and ends at the first edge + # where the left coordinate is equal to the new tree's + # right coordinate. 
+ j = left_current_index + self.out_range.start = j + # TODO This could be done with binary search + while j >= 0 and left_coords[left_order[j]] >= right: + j -= 1 + self.out_range.stop = j + + if self.index == self.ts.num_trees: + # No edges, so out_range should be empty + self.out_range.start = self.out_range.stop + + # The range of edges we need to consider for the new tree + # must have left coordinate < right + j = right_current_index + while j >= 0 and left_coords[right_order[j]] >= right: + j -= 1 + self.in_range.start = j + # We stop at the first edge with right coordinate < right + while j >= 0 and right_coords[right_order[j]] >= right: + j -= 1 + self.in_range.stop = j + + self.interval.right = right + self.interval.left = breakpoints[index] + self.out_range.order = left_order + self.in_range.order = right_order + self.index = index + + def step(self, direction): + if direction == FORWARD: + return self.next() + elif direction == REVERSE: + return self.prev() + else: + raise ValueError("Direction must be FORWARD (+1) or REVERSE (-1)") + def mean_descendants(ts, reference_sets): """