Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Jul 7, 2023
1 parent b54583f commit d6216ff
Showing 1 changed file with 52 additions and 1 deletion.
53 changes: 52 additions & 1 deletion python/tests/test_tree_iters.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Tests for tree iterator schemes. Mostly used to develop the incremental
iterator infrastructure.
"""
import numpy as np
import pytest

import tskit
Expand All @@ -37,6 +38,8 @@
REVERSE = -1


# TODO deal with direction change and calling next()/prev on the null
# tree
class TreePosition:
def __init__(self, ts):
self.direction = FORWARD
Expand Down Expand Up @@ -125,6 +128,55 @@ def prev(self):
return j >= 0 or right > 0


class StatefulTree:
"""
Just enough functionality to mimic the low-level tree implementation
for testing of forward/backward moving.
"""

def __init__(self, ts):
self.ts = ts
self.tree_pos = TreePosition(ts)
self.index = -1
self.parent = np.zeros(ts.num_nodes + 1, dtype=int) - 1

def next(self): # NOQA: A003
valid = self.tree_pos.next()
if valid:
for j in range(*self.tree_pos.out_range):
e = self.tree_pos.out_order[j]
c = self.ts.edges_child[e]
self.parent[c] = -1
for j in range(*self.tree_pos.in_range):
e = self.tree_pos.in_order[j]
c = self.ts.edges_child[e]
p = self.ts.edges_parent[e]
self.parent[c] = p
self.index += 1
else:
self.index = -1
return valid

def prev(self):
if self.index == -1:
self.index = self.ts.num_trees
valid = self.tree_pos.prev()
if valid:
for j in range(*self.tree_pos.out_range, -1):
e = self.tree_pos.out_order[j]
c = self.ts.edges_child[e]
self.parent[c] = -1
for j in range(*self.tree_pos.in_range, -1):
e = self.tree_pos.in_order[j]
c = self.ts.edges_child[e]
p = self.ts.edges_parent[e]
self.parent[c] = p
self.index -= 1
else:
self.index = -1
return valid


def check_iters_forward(ts):
alg_t_output = tsutil.algorithm_T(ts)

Expand Down Expand Up @@ -191,7 +243,6 @@ def test_forward_full(self, n):
@pytest.mark.parametrize("n", [2, 3, 4])
def test_back_full(self, n):
ts = tsutil.all_trees_ts(n)
print(ts.draw_text())
check_iters_back(ts)


Expand Down

0 comments on commit d6216ff

Please sign in to comment.