Skip to content

Commit

Permalink
First pass at direction switching
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Jul 9, 2023
1 parent 3b8ee27 commit 92c1591
Showing 1 changed file with 142 additions and 19 deletions.
161 changes: 142 additions & 19 deletions python/tests/test_tree_positioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,38 +42,57 @@
# tree
class TreePosition:
def __init__(self, ts):
self.num_trees = ts.num_trees
self.index = -1
self.direction = FORWARD
self.num_edges = ts.num_edges
self.sequence_length = ts.sequence_length
self.edges_left = ts.edges_left
self.edges_right = ts.edges_right
self.edge_left_order = ts.indexes_edge_insertion_order
self.edge_right_order = ts.indexes_edge_removal_order
self.edge_left_current_index = 0
self.edge_right_current_index = 0
self.interval = [0, 0]
self.in_order = None
self.out_order = None
self.in_range = [0, 0]
self.out_range = [0, 0]

def __str__(self):
s = f"index: {self.index}\ninterval: {self.interval}\n"
s += f"direction: {self.direction}\n"
s += f"in_range: {self.in_range}\n"
s += f"out_range: {self.out_range}\n"
return s

def next(self): # NOQA: A003
if self.index == -1:
self.direction = FORWARD

direction_change = int(self.direction != FORWARD)
self.direction = FORWARD

M = self.num_edges
in_coords = self.edges_left
in_order = self.edge_left_order
out_coords = self.edges_right
out_order = self.edge_right_order
x = self.interval[1]

k = self.out_range[1]
k = self.edge_right_current_index + direction_change
self.out_range[0] = k
while k < M and out_coords[out_order[k]] == x:
k += 1
self.out_range[1] = k
self.edge_right_current_index = k

j = self.in_range[1]
j = self.edge_left_current_index + direction_change
self.in_range[0] = j
while j < M and in_coords[in_order[j]] == x:
j += 1
self.in_range[1] = j
self.edge_left_current_index = j

left = x
right = self.sequence_length
Expand All @@ -86,32 +105,43 @@ def next(self): # NOQA: A003
self.interval[:] = [left, right]
self.out_order = out_order
self.in_order = in_order
return j < M or left < self.sequence_length
self.index += 1
assert (self.index != self.num_trees) == (j < M or left < self.sequence_length)
if self.index == self.num_trees:
self.index = -1
return self.index != -1

def prev(self):
M = self.num_edges
if self.interval[1] == 0:
if self.index == -1:
self.index = self.num_trees
self.interval[0] = self.sequence_length
self.out_range[1] = M - 1
self.in_range[1] = M - 1
self.edge_right_current_index = M - 1
self.edge_left_current_index = M - 1
self.direction = REVERSE

direction_change = int(self.direction != REVERSE)
self.direction = REVERSE

in_coords = self.edges_right
in_order = self.edge_right_order
out_coords = self.edges_left
out_order = self.edge_left_order
x = self.interval[0]

k = self.out_range[1]
k = self.edge_left_current_index - direction_change
self.out_range[0] = k
while k >= 0 and out_coords[out_order[k]] == x:
k -= 1
self.out_range[1] = k
self.edge_left_current_index = k

j = self.in_range[1]
j = self.edge_right_current_index - direction_change
self.in_range[0] = j
while j >= 0 and in_coords[in_order[j]] == x:
j -= 1
self.in_range[1] = j
self.edge_right_current_index = j

right = x
left = 0
Expand All @@ -125,7 +155,11 @@ def prev(self):
self.interval[:] = [left, right]
self.out_order = out_order
self.in_order = in_order
return j >= 0 or right > 0

assert self.index >= 0
self.index -= 1
assert (self.index != -1) == (j >= 0 or right > 0)
return self.index != -1


class StatefulTree:
Expand All @@ -137,8 +171,18 @@ class StatefulTree:
def __init__(self, ts):
self.ts = ts
self.tree_pos = TreePosition(ts)
self.index = -1
self.parent = np.zeros(ts.num_nodes + 1, dtype=int) - 1
self.parent = [-1 for _ in range(ts.num_nodes)]

def __str__(self):
s = f"parent: {self.parent}\nposition:\n"
for line in str(self.tree_pos).splitlines():
s += f"\t{line}\n"
return s

def assert_equal(self, other):
assert self.parent == other.parent
assert self.tree_pos.index == other.tree_pos.index
assert self.tree_pos.interval == other.tree_pos.interval

def next(self): # NOQA: A003
valid = self.tree_pos.next()
Expand All @@ -152,14 +196,17 @@ def next(self): # NOQA: A003
c = self.ts.edges_child[e]
p = self.ts.edges_parent[e]
self.parent[c] = p
self.index += 1
else:
self.index = -1
return valid

def seek_forward(self, index):
while self.tree_pos.index != index:
self.next()

def seek_backward(self, index):
while self.tree_pos.index != index:
self.prev()

def prev(self):
if self.index == -1:
self.index = self.ts.num_trees
valid = self.tree_pos.prev()
if valid:
for j in range(*self.tree_pos.out_range, -1):
Expand All @@ -171,9 +218,6 @@ def prev(self):
c = self.ts.edges_child[e]
p = self.ts.edges_parent[e]
self.parent[c] = p
self.index -= 1
else:
self.index = -1
return valid


Expand Down Expand Up @@ -274,6 +318,80 @@ def check_iters_back(ts):
assert i == -1


def check_forward_back_sweep(ts):
alg_t_output = [
(list(interval), list(parent)) for interval, parent in tsutil.algorithm_T(ts)
]
for j in range(ts.num_trees - 1):
tree = StatefulTree(ts)
# Seek forward to j
k = 0
while k <= j:
tree.next()
interval, parent = alg_t_output[k]
assert tree.tree_pos.index == k
assert tree.tree_pos.interval == interval
assert parent == tree.parent
k += 1
k = j
# And back to zero
while k >= 0:
interval, parent = alg_t_output[k]
assert tree.tree_pos.index == k
assert tree.tree_pos.interval == interval
assert parent == tree.parent
tree.prev()
k -= 1


class TestDirectionSwitching:
# 2.00┊ ┊ 4 ┊ 4 ┊ 4 ┊
# ┊ ┊ ┏━┻┓ ┊ ┏┻━┓ ┊ ┏┻━┓ ┊
# 1.00┊ 3 ┊ ┃ 3 ┊ 3 ┃ ┊ 3 ┃ ┊
# ┊ ┏━╋━┓ ┊ ┃ ┏┻┓ ┊ ┏┻┓ ┃ ┊ ┏┻┓ ┃ ┊
# 0.00┊ 0 1 2 ┊ 0 1 2 ┊ 0 2 1 ┊ 0 1 2 ┊
# 0 1 2 3 4
# index 0 1 2 3
def ts(self):
return tsutil.all_trees_ts(3)

@pytest.mark.parametrize("index", [1, 2, 3])
def test_forward_to_prev(self, index):
tree1 = StatefulTree(self.ts())
tree1.seek_forward(index)
tree1.prev()
tree2 = StatefulTree(self.ts())
tree2.seek_forward(index - 1)
tree1.assert_equal(tree2)
tree2 = StatefulTree(self.ts())
tree2.seek_backward(index - 1)
tree1.assert_equal(tree2)

@pytest.mark.parametrize("index", [0, 1, 2])
def test_backward_to_next(self, index):
tree1 = StatefulTree(self.ts())
tree1.seek_backward(index)
tree1.next()
tree2 = StatefulTree(self.ts())
tree2.seek_backward(index + 1)
tree1.assert_equal(tree2)
tree2 = StatefulTree(self.ts())
tree2.seek_forward(index + 1)
tree1.assert_equal(tree2)

@pytest.mark.parametrize("index", [1, 2, 3])
def test_forward_next_prev(self, index):
tree1 = StatefulTree(self.ts())
tree1.seek_forward(index)
tree1.prev()
tree2 = StatefulTree(self.ts())
tree2.seek_forward(index - 1)
tree1.assert_equal(tree2)
tree2 = StatefulTree(self.ts())
tree2.seek_backward(index - 1)
tree1.assert_equal(tree2)


class TestAllTreesTs:
@pytest.mark.parametrize("n", [2, 3, 4])
def test_forward_full(self, n):
Expand All @@ -285,6 +403,11 @@ def test_back_full(self, n):
ts = tsutil.all_trees_ts(n)
check_iters_back(ts)

@pytest.mark.parametrize("n", [2, 3, 4])
def test_forward_back(self, n):
ts = tsutil.all_trees_ts(n)
check_forward_back_sweep(ts)


class TestSuiteExamples:
@pytest.mark.parametrize("ts", get_example_tree_sequences())
Expand Down

0 comments on commit 92c1591

Please sign in to comment.