Skip to content

Commit

Permalink
just in case
Browse files Browse the repository at this point in the history
  • Loading branch information
hfr1tz3 committed Aug 17, 2023
1 parent 8ece69c commit f18f3d1
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 26 deletions.
69 changes: 45 additions & 24 deletions python/tests/test_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -8179,8 +8179,8 @@ def py_extend_edges():
last_num_edges = ts.num_edges
return ts

def forward_extend(self, forwards=True):
# TODO: move python version from trees.py here
def _extend(self, forwards=True):
print("forwards:", forwards)
num_edges = np.full(self.num_nodes, 0)

t = self.tables
Expand All @@ -8191,45 +8191,60 @@ def forward_extend(self, forwards=True):

# edge diff stuff
M = edges.num_rows
I = self.indexes_edge_insertion_order
O = self.indexes_edge_removal_order
if forwards:
I = self.indexes_edge_insertion_order
O = self.indexes_edge_removal_order
else:
I = np.flip(self.indexes_edge_removal_order)
O = np.flip(self.indexes_edge_insertion_order)
tj = 0
tk = 0
left = 0
# "here" will be left if fowards else right
here = 0 if forwards else self.sequence_length
edges_out = []
edges_in = []

while (tj < M) or (left < self.sequence_length):
endpoint = self.sequence_length if forwards else 0
sign = +1 if forwards else -1
near_edge = edges.left if forwards else edges.right
far_edge = edges.right if forwards else edges.left

while (tj < M) or (forwards and here < endpoint):
# clear out non-extended or postponed edges
edges_out = [[e, False] for e, x in edges_out if x]
edges_in = [[e, False] for e, x in edges_in if x]

# Find edges_out between trees
while (tk < M) and (edges.right[O[tk]] == left):
while (tk < M) and (far_edge[O[tk]] == here):
edges_out.append([O[tk], False])
num_edges[edges.parent[O[tk]]] -= 1
num_edges[edges.child[O[tk]]] -= 1
#print("Edge Out", tk, edges[O[tk]])
tk += 1
# Find edges_in between trees
while (tj < M) and (edges.left[I[tj]] == left):
while (tj < M) and (near_edge[I[tj]] == here):
edges_in.append([I[tj], False])
num_edges[edges.parent[I[tj]]] += 1
num_edges[edges.child[I[tj]]] += 1
#print("Edge In", tj, edges[I[tj]])
tj += 1

# Find smallest length right enpoint of all edges in edges_in and edges_out
# right should equal the endpoint of a T_k
right = self.sequence_length
if tk < M:
right = min(right, edges.right[O[tk]])
if tj < M:
right = min(right, edges.left[I[tj]])
# Find smallest length right endpoint of all edges in edges_in and edges_out
# there should equal the endpoint of a T_k
there = self.sequence_length if forwards else 0
if forwards:
if tk < M:
there = min(there, far_edge[O[tk]])
if tj < M:
there = min(there, near_edge[I[tj]])
else:
if tk < M:
there = max(there, far_edge[O[tk]])
if tj < M:
there = max(there, near_edge[I[tj]])
print("All Edges Out", edges_out)
print("All Edges In", edges_in)
assert np.all(num_edges >= 0)
print("-------------", left, len(edges_out), len(edges_in))
print("-------------", here, len(edges_out), len(edges_in))
for ex1 in edges_out:
#print("e1:", e1, [edges.parent[O[e1]], edges.child[O[e1]]], edges[O[e1]])
if not ex1[1]:
Expand All @@ -8245,7 +8260,7 @@ def forward_extend(self, forwards=True):
for ex_in in edges_in:
e_in = ex_in[0]
#print("ein", e_in, [edges.parent[I[e_in]], edges.child[I[e_in]]])
if edges.right[e_in] > left:
if sign * far_edge[e_in] > sign * here:
if (
edges.child[e1] == edges.child[e_in]
and edges.parent[e2] == edges.parent[e_in]
Expand All @@ -8255,14 +8270,19 @@ def forward_extend(self, forwards=True):
ex1[1] = True
ex2[1] = True
ex_in[1] = True
new_right[e1] = right
new_right[e2] = right
new_left[e_in] = right
if forwards:
new_right[e1] = there
new_right[e2] = there
new_left[e_in] = there
else:
new_left[e1] = there
new_left[e2] = there
new_right[e_in] = there
# amend num_edges: the intermediate
# node has 2 edges instead of 0
num_edges[edges.parent[e1]] += 2
# cleanup at end of loop
left = right
here = there

for j in range(edges.num_rows):
left = new_left[j]
Expand Down Expand Up @@ -8317,8 +8337,9 @@ def verify_extend_edges(self, ts, ets):
for j in (k-1, k+1):
if j < 0 or j >= len(chains):
next
this_chains = chains[j]
print(j, this_chains)
else:
this_chains = chains[j]
print(j, this_chains)
for a, b, c in this_chains:
if a in tt.nodes() and tt.parent(a) == c and b not in tt.nodes():
# the relationship a <- b <- c should still be in the tree,
Expand Down
4 changes: 2 additions & 2 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -6836,8 +6836,8 @@ def _extend(self, forwards=True):
I = self.indexes_edge_insertion_order
O = self.indexes_edge_removal_order
else:
I = np.flip(self.indexes_edge_insertion_order)
O = np.flip(self.indexes_edge_removal_order)
I = np.flip(self.indexes_edge_removal_order)
O = np.flip(self.indexes_edge_insertion_order)
tj = 0
tk = 0
# "here" will be left if fowards else right
Expand Down

0 comments on commit f18f3d1

Please sign in to comment.