Skip to content

Commit

Permalink
Remove edge_diffs
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Jul 14, 2023
1 parent 774b000 commit fcd9c10
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions python/tests/test_haplotype_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

import _tskit
import tskit
from tests import tsutil

MISSING = -1

Expand Down Expand Up @@ -112,8 +113,8 @@ def __init__(
self.N = np.zeros(ts.num_nodes, dtype=int)
# Efficiently compute the allelic state at a site
self.allelic_state = np.zeros(ts.num_nodes, dtype=int) - 1
# Diffs so we can can update T and T_index between trees.
self.edge_diffs = self.ts.edge_diffs()
# TreePosition so we can can update T and T_index between trees.
self.tree_pos = tsutil.TreePosition(ts)
self.parent = np.zeros(self.ts.num_nodes, dtype=int) - 1
self.tree = tskit.Tree(self.ts)
self.output = None
Expand Down Expand Up @@ -239,9 +240,12 @@ def update_tree(self):
parent = self.parent
T_index = self.T_index
T = self.T
_, edges_out, edges_in = next(self.edge_diffs)
self.tree_pos.next()
assert self.tree_pos.index == self.tree.index

for edge in edges_out:
for j in range(self.tree_pos.out_range.start, self.tree_pos.out_range.stop):
e = self.tree_pos.out_range.order[j]
edge = self.ts.edge(e)
u = edge.child
if T_index[u] == -1:
# Make sure the subtree we're detaching has an T_index-value at the root.
Expand All @@ -254,7 +258,9 @@ def update_tree(self):
)
parent[edge.child] = -1

for edge in edges_in:
for j in range(self.tree_pos.in_range.start, self.tree_pos.in_range.stop):
e = self.tree_pos.in_range.order[j]
edge = self.ts.edge(e)
parent[edge.child] = edge.parent
u = edge.parent
if parent[edge.parent] == -1:
Expand Down Expand Up @@ -776,12 +782,10 @@ def ls_backward_tree_mirrored(
return ba.run(h)


def ls_backward_tree(
h, ts_mirror, rho, mu, normalisation_factor, precision=30, alleles=None
):
alleles, n_alleles = get_site_alleles(ts_mirror, h, alleles)
def ls_backward_tree(h, ts, rho, mu, normalisation_factor, precision=30, alleles=None):
alleles, n_alleles = get_site_alleles(ts, h, alleles)
ba = BackwardAlgorithm(
ts_mirror,
ts,
rho,
mu,
alleles,
Expand Down

0 comments on commit fcd9c10

Please sign in to comment.