Skip to content

Commit

Permalink
Multi-arg tMRCA
Browse files Browse the repository at this point in the history
  • Loading branch information
hyanwong committed Jul 28, 2023
1 parent ab5ef8f commit cf97615
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 10 deletions.
24 changes: 21 additions & 3 deletions python/tests/test_highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,42 +886,60 @@ def test_minlex_postorder_multiple_roots(self):


class TestMRCA:
"""
Test both the tree.mrca and tree.tmrca methods.
"""

t = tskit.Tree.generate_balanced(3)
# 4
# ┏━┻┓
# ┃ 3
# ┃ ┏┻┓
# 0 1 2

def test_two_or_more_args(self):
assert self.t.mrca(2, 1) == 3
assert self.t.mrca(0, 1, 2) == 4
@pytest.mark.parametrize("args, expected", [((2, 1), 3), ((0, 1, 2), 4)])
def test_two_or_more_args(self, args, expected):
assert self.t.mrca(*args) == expected
assert self.t.tmrca(*args) == self.t.tree_sequence.nodes_time[expected]

def test_less_than_two_args(self):
with pytest.raises(ValueError):
self.t.mrca(1)
with pytest.raises(ValueError):
self.t.tmrca(1)

def test_no_args(self):
with pytest.raises(ValueError):
self.t.mrca()
with pytest.raises(ValueError):
self.t.tmrca()

def test_same_args(self):
assert self.t.mrca(0, 0, 0, 0) == 0
assert self.t.tmrca(0, 0, 0, 0) == self.t.tree_sequence.nodes_time[0]

def test_different_tree_levels(self):
assert self.t.mrca(0, 3) == 4
assert self.t.tmrca(0, 3) == self.t.tree_sequence.nodes_time[4]

def test_out_of_bounds_args(self):
with pytest.raises(ValueError):
self.t.mrca(0, 6)
with pytest.raises(ValueError):
self.t.tmrca(0, 6)

def test_virtual_root_arg(self):
assert self.t.mrca(0, 5) == 5
assert np.isposinf(self.t.tmrca(0, 5))

def test_multiple_roots(self):
ts = tskit.Tree.generate_balanced(10).tree_sequence
ts = ts.delete_intervals([ts.first().interval])
assert ts.first().mrca(*ts.samples()) == tskit.NULL
# We decided to raise an error for tmrca here, rather than report inf
# see https://github.com/tskit-dev/tskit/issues/2801
with pytest.raises(ValueError, match="do not share a common ancestor"):
ts.first().tmrca(0, 6)


class TestPathLength:
Expand Down
18 changes: 11 additions & 7 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,7 +996,7 @@ def mrca(self, *args):
"""
Returns the most recent common ancestor of the specified nodes.
:param int `*args`: input node IDs, must be at least 2.
:param int `*args`: input node IDs, at least 2 arguments are required.
:return: The node ID of the most recent common ancestor of the
input nodes, or :data:`tskit.NULL` if the nodes do not share
a common ancestor in the tree.
Expand All @@ -1015,12 +1015,12 @@ def get_tmrca(self, u, v):
# Deprecated alias for tmrca
return self.tmrca(u, v)

def tmrca(self, u, v):
def tmrca(self, *args):
"""
Returns the time of the most recent common ancestor of the specified
nodes. This is equivalent to::
>>> tree.time(tree.mrca(u, v))
>>> tree.time(tree.mrca(*args))
.. note::
If you are using this method to calculate average tmrca values along the
Expand All @@ -1031,12 +1031,16 @@ def tmrca(self, u, v):
nodes, for samples at time 0 the resulting statistics will be exactly
twice the tmrca value.
:param int u: The first node.
:param int v: The second node.
:return: The time of the most recent common ancestor of u and v.
:param `*args`: input node IDs, at least 2 arguments are required.
:return: The time of the most recent common ancestor of all the nodes.
:rtype: float
:raises ValueError: If the nodes do not share a single common ancestor in this
tree (i.e., if ``tree.mrca(*args) == tskit.NULL``)
"""
return self.get_time(self.get_mrca(u, v))
mrca = self.mrca(*args)
if mrca == tskit.NULL:
raise ValueError(f"Nodes {args} do not share a common ancestor in the tree")
return self.get_time(mrca)

def get_parent(self, u):
# Deprecated alias for parent
Expand Down

0 comments on commit cf97615

Please sign in to comment.