Skip to content

Commit

Permalink
Implement tree.ancestors
Browse files Browse the repository at this point in the history
  • Loading branch information
hyanwong committed Oct 16, 2024
1 parent 2fb6ab8 commit 6446325
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 10 deletions.
1 change: 1 addition & 0 deletions docs/python-api.md
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,7 @@ Iterator access
.. autosummary::
Tree.nodes
Tree.ancestors
Array access
.. autosummary::
Expand Down
1 change: 1 addition & 0 deletions python/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@

- Add comma separation to all display numbers. (:user:`benjeffery`, :issue:`3017`, :pr:`3018`)

- Add ``Tree.ancestors(u)`` method, which returns an iterator over the ancestors of node ``u``. (:user:`hyanwong`, :issue:`2706`, :pr:`3021`)

--------------------
[0.5.8] - 2024-06-27
Expand Down
11 changes: 1 addition & 10 deletions python/tests/test_balance_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,6 @@
# we can remove this.


def node_path(tree, u):
    """
    Return the list of nodes found by repeatedly following ``tree.parent``
    upwards from ``u``.

    The node ``u`` itself is not included; the first element (if any) is
    the parent of ``u`` and the walk stops as soon as ``tskit.NULL`` is
    returned as a parent.
    """
    lineage = []
    node = u
    while True:
        node = tree.parent(node)
        if node == tskit.NULL:
            # Ran off the top of the tree: nothing further to record.
            return lineage
        lineage.append(node)


def sackin_index_definition(tree):
    """
    Definition-style implementation of the Sackin index: the sum of
    ``tree.depth(u)`` over every node ``u`` yielded by ``tree.leaves()``.
    """
    total_depth = 0
    for leaf in tree.leaves():
        total_depth += tree.depth(leaf)
    return total_depth

Expand Down Expand Up @@ -79,7 +70,7 @@ def b2_index_definition(tree, base=10):
if tree.num_roots != 1:
raise ValueError("B2 index is only defined for trees with one root")
proba = [
np.prod([1 / tree.num_children(u) for u in node_path(tree, leaf)])
np.prod([1 / tree.num_children(u) for u in tree.ancestors(leaf)])
for leaf in tree.leaves()
]
return -sum(p * math.log(p, base) for p in proba)
Expand Down
17 changes: 17 additions & 0 deletions python/tests/test_highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3822,6 +3822,23 @@ def test_num_children(self):
for u in tree.nodes():
assert tree.num_children(u) == len(tree.children(u))

def test_ancestors(self):
    """
    Check ``Tree.ancestors`` against ancestor lists built independently
    from ``Tree.parent`` on a tree made by
    ``generate_balanced(10, arity=3)``.
    """
    tree = tskit.Tree.generate_balanced(10, arity=3)
    # Expected ancestors of each node, nearest first. Nodes without a
    # parent keep the empty list they are initialised with.
    expected = {u: [] for u in range(tree.tree_sequence.num_nodes)}
    # Preorder visits a parent before its children, so the parent's list
    # is already complete when a child extends it.
    for u in tree.nodes(order="preorder"):
        parent = tree.parent(u)
        if parent != tskit.NULL:
            expected[u] = [parent] + expected[parent]
    for u in tree.nodes():
        assert list(tree.ancestors(u)) == expected[u]

def test_ancestors_empty(self):
    """
    After deleting the interval ``[0, 1]`` from a comb-tree sequence,
    every sample must report no ancestors in the first tree.
    """
    # NOTE(review): presumably [0, 1] spans the whole sequence, leaving
    # the first tree without edges - confirm against generate_comb defaults.
    comb_ts = tskit.Tree.generate_comb(10).tree_sequence
    first_tree = comb_ts.delete_intervals([[0, 1]]).first()
    for sample in comb_ts.samples():
        assert list(first_tree.ancestors(sample)) == []

@pytest.mark.parametrize("ts", get_example_tree_sequences())
def test_virtual_root_semantics(self, ts):
for tree in ts.trees():
Expand Down
9 changes: 9 additions & 0 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,6 +1091,15 @@ def parent_array(self):
"""
return self._parent_array

def ancestors(self, u):
    """
    Returns an iterator over the ancestors of node ``u`` in this tree.

    The ancestors are produced lazily by repeatedly calling
    ``self.parent``, beginning with the parent of ``u`` and ending when
    ``-1`` is returned as a parent. The node ``u`` itself is not yielded,
    so a node without a parent gives an empty iterator.

    :param u: The node whose ancestors are required.
    :return: An iterator over the ancestors of ``u``, nearest first.
    """
    node = u
    while True:
        node = self.parent(node)
        if node == -1:
            # No parent: the chain of ancestors is exhausted.
            return
        yield node

# Quintuply linked tree structure.

def left_child(self, u):
Expand Down

0 comments on commit 6446325

Please sign in to comment.