Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generic iterator class for trees and variants #2762

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions python/tests/test_genotypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,9 @@ def test_simple_case(self, ts_fixture):
ts = ts_fixture
test_variant = tskit.Variant(ts)
test_variant.decode(1)
for v in ts.variants(left=ts.site(1).position, right=ts.site(2).position):
v_iter = ts.variants(left=ts.site(1).position, right=ts.site(2).position)
assert len(v_iter) == 1
for v in v_iter:
# should only decode the first variant
assert v.site.id == 1
assert np.all(v.genotypes == test_variant.genotypes)
Expand All @@ -686,7 +688,9 @@ def test_left(self, left, expected):
for x in range(int(tables.sequence_length)):
tables.sites.add_row(position=x, ancestral_state="A")
ts = tables.tree_sequence()
positions = [var.site.position for var in ts.variants(left=left)]
v_iter = ts.variants(left=left)
assert len(v_iter) == len(expected)
positions = [var.site.position for var in v_iter]
assert positions == expected

@pytest.mark.parametrize(
Expand All @@ -706,7 +710,9 @@ def test_right(self, right, expected):
for x in range(int(tables.sequence_length)):
tables.sites.add_row(position=x, ancestral_state="A")
ts = tables.tree_sequence()
positions = [var.site.position for var in ts.variants(right=right)]
v_iter = ts.variants(right=right)
assert len(v_iter) == len(expected)
positions = [var.site.position for var in v_iter]
assert positions == expected

@pytest.mark.parametrize("bad_left", [-1, 10, 100, np.nan, np.inf, -np.inf])
Expand Down
56 changes: 54 additions & 2 deletions python/tests/test_highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1724,6 +1724,48 @@ def test_trees_interface(self):
assert t.get_num_tracked_samples(0) == 0
assert list(t.samples(0)) == [0]

def test_trees_bad_left_right(self):
    # The tree sequence is built with ``span=1``; each (left, right) pair
    # below is expected to be rejected by ``trees()`` with a ValueError.
    ts = tskit.Tree.generate_balanced(10, span=1).tree_sequence
    bad_ranges = [
        (0.5, 0.5),
        (0.5, 0.4),
        (0.5, 1.1),
        (-0.1, 0.1),
        (1, 1.5),
    ]
    for bad_left, bad_right in bad_ranges:
        with pytest.raises(ValueError):
            ts.trees(left=bad_left, right=bad_right)

def test_trees_left_right_one_tree(self):
    # A single-tree sequence restricted to a sub-interval must report a
    # length of one and yield exactly that tree, in either direction.
    ts = tskit.Tree.generate_balanced(10).tree_sequence
    iterator_factories = (
        lambda: ts.trees(left=0.5, right=0.6),
        lambda: reversed(ts.trees(left=0.5, right=0.6)),
    )
    for make_iterator in iterator_factories:
        tree_iterator = make_iterator()
        assert len(tree_iterator) == 1
        copies = [tree.copy() for tree in tree_iterator]
        assert len(copies) == 1
        assert copies[0] == ts.first()

@pytest.mark.parametrize(
    "interval", [(0, 0.5), (0.4, 0.6), (0.5, np.nextafter(0.5, 1)), (0.5, 1)]
)
def test_trees_left_right_many_trees(self, interval):
    # Walk every tree in the sequence and compare those overlapping
    # ``interval`` against what the restricted iterator yields, in order.
    lower, upper = interval[0], interval[1]
    ts = msprime.simulate(5, recombination_rate=10, random_seed=1)
    assert ts.num_trees > 10
    tree_iter = ts.trees(left=lower, right=upper)
    expected_length = len(tree_iter)
    num_overlapping = 0
    for tree in ts.trees():
        overlaps = tree.interval[1] > lower and tree.interval[0] < upper
        if not overlaps:
            continue
        num_overlapping += 1
        assert tree.interval == next(tree_iter).interval
    assert num_overlapping == expected_length

@pytest.mark.parametrize("ts", get_example_tree_sequences())
def test_get_pairwise_diversity(self, ts):
with pytest.raises(ValueError, match="at least one element"):
Expand Down Expand Up @@ -2994,8 +3036,18 @@ def test_trees_params(self):
)
# Skip the first param, which is `tree_sequence` and `self` respectively
tree_class_params = tree_class_params[1:]
# The trees iterator has some extra (deprecated) aliases
trees_iter_params = trees_iter_params[1:-3]
# The trees iterator has some extra (deprecated) aliases at the end
num_deprecated = 3
trees_iter_params = trees_iter_params[1:-num_deprecated]

# The trees iterator also has left/right/copy params which aren't in __init__()
assert trees_iter_params[-1][0] == "copy"
trees_iter_params = trees_iter_params[:-1]
assert trees_iter_params[-1][0] == "right"
trees_iter_params = trees_iter_params[:-1]
assert trees_iter_params[-1][0] == "left"
trees_iter_params = trees_iter_params[:-1]

assert trees_iter_params == tree_class_params


Expand Down
30 changes: 30 additions & 0 deletions python/tskit/genotypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,36 @@ def decode(self, site_id) -> None:
"""
self._ll_variant.decode(site_id)

def next(self):  # noqa A002
    """
    Advance this variant to the following site.

    If nothing has been decoded yet (the site ID is ``tskit.NULL``) the
    first site is decoded.

    :return: ``True`` if a site was decoded, ``False`` if the variant was
        already positioned on the last site.
    """
    current_id = self._ll_variant.site_id
    if current_id == self.tree_sequence.num_sites - 1:
        # TODO: should also set the variant to the null state
        return False
    target_id = 0 if current_id == tskit.NULL else current_id + 1
    self.decode(target_id)
    return True

def prev(self):
    """
    Move this variant back to the preceding site.

    If nothing has been decoded yet (the site ID is ``tskit.NULL``) the
    last site is decoded.

    :return: ``True`` if a site was decoded, ``False`` if the variant was
        already positioned on the first site or the tree sequence has no
        sites at all.
    """
    current_id = self._ll_variant.site_id
    if current_id == 0:
        # TODO: should also set the variant to the null state
        return False
    if current_id == tskit.NULL:
        num_sites = self.tree_sequence.num_sites
        if num_sites == 0:
            # Nothing to decode: avoid asking for the invalid site ID -1.
            return False
        self.decode(num_sites - 1)
    else:
        self.decode(current_id - 1)
    return True

def copy(self) -> Variant:
"""
Create a copy of this Variant. Note that calling :meth:`decode` on the
Expand Down
178 changes: 143 additions & 35 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -3863,15 +3863,26 @@ def load_text(
return tc.tree_sequence()


class ObjectIterator:
    # Simple class providing forward and backward iteration over a
    # mutable object with ``next()`` and ``prev()`` methods, e.g.
    # a Tree or a Variant. ``interval`` gives the half-open ``[left, right)``
    # bounds of the iteration, and should already have been checked using
    # _check_genomic_range(left, right).
    # If ``return_copies`` is True, the iterator returns ``obj.copy()`` at
    # each step rather than the (repeatedly mutated) object itself; this is
    # likely to be significantly less efficient.
    # It can be useful to define __len__ on subclasses of this iterator,
    # which e.g. allows progress bars to provide useful feedback.

    def __init__(self, obj, interval, return_copies=False):
        self._obj = obj
        self.min_pos = interval[0]
        self.max_pos = interval[1]
        self.return_copies = return_copies
        # Direction of travel; flipped by __reversed__
        self.forward = True
        # False until the first __next__ call has positioned the object
        self.started = False
        # Set once the object has moved outside [min_pos, max_pos)
        self.finished = False

    def __iter__(self):
        return self

    def __reversed__(self):
        # Reverses this iterator in place and returns it (no new object)
        self.forward = False
        return self

    def obj_left(self):
        # Used to work out where to stop iterating when going backwards.
        # Override with code to return the left coordinate of self._obj
        raise NotImplementedError()

    def obj_right(self):
        # Used to work out when to stop iterating when going backwards.
        # Override with code to return the right coordinate of self._obj
        raise NotImplementedError()

    def seek_to_start(self):
        # Override to set the object position to self.min_pos
        raise NotImplementedError()

    def seek_to_end(self):
        # Override to set the object position just before self.max_pos
        raise NotImplementedError()

    def _before_min_pos(self):
        # True when the object lies wholly to the left of ``min_pos``.
        # Objects spanning a half-open interval [left, right) are outside
        # the range as soon as right == min_pos, whereas point-like objects
        # (left == right) located exactly at min_pos are still inside.
        return self.obj_left() < self.min_pos and self.obj_right() <= self.min_pos

    def __next__(self):
        if self.finished:
            raise StopIteration()
        if not self.started:
            # First call: position the object at the appropriate end
            if self.forward:
                self.seek_to_start()
            else:
                self.seek_to_end()
            self.started = True
        elif self.forward:
            if not self._obj.next() or self.obj_left() >= self.max_pos:
                self.finished = True
        else:
            if not self._obj.prev() or self._before_min_pos():
                self.finished = True
        if self.finished:
            raise StopIteration()
        return self._obj.copy() if self.return_copies else self._obj


class TreeIterator(ObjectIterator):
    """
    An iterator over some or all of the :class:`trees<Tree>`
    in a :class:`TreeSequence`.
    """

    def obj_left(self):
        return self._obj.interval.left

    def obj_right(self):
        return self._obj.interval.right

    def seek_to_start(self):
        self._obj.seek(self.min_pos)

    def seek_to_end(self):
        # ``max_pos`` is exclusive, so seek to the largest float below it
        self._obj.seek(np.nextafter(self.max_pos, -np.inf))

    def __len__(self):
        """
        The number of trees over which a newly created iterator will iterate.
        """
        ts = self._obj.tree_sequence
        if self.min_pos == 0 and self.max_pos == ts.sequence_length:
            # a common case: don't incur the cost of searching through the breakpoints
            return ts.num_trees
        breaks = ts.breakpoints(as_array=True)
        # Index of the first breakpoint strictly greater than min_pos, i.e.
        # one past the index of the tree containing min_pos
        left_index = breaks.searchsorted(self.min_pos, side="right")
        # Index of the first breakpoint >= max_pos, i.e. one past the index
        # of the last tree that starts before the (exclusive) max_pos
        right_index = breaks.searchsorted(self.max_pos, side="left")
        return int(right_index - left_index + 1)


class VariantIterator(ObjectIterator):
    """
    An iterator over some or all of the :class:`variants<Variant>`
    in a :class:`TreeSequence`.
    """

    def __init__(self, variant, interval, copy):
        super().__init__(variant, interval, copy)
        ts = variant.tree_sequence
        if interval[0] == 0 and interval[1] == ts.sequence_length:
            # a common case: don't incur the cost of searching through the positions
            first_site, stop_site = 0, ts.num_sites
        else:
            first_site, stop_site = ts.sites_position.searchsorted(interval)
        # Half-open range of site IDs [first_site, stop_site) to visit,
        # stored as plain ints whichever path computed them
        self.min_max_sites = [int(first_site), int(stop_site)]
        if self.min_max_sites[0] >= self.min_max_sites[1]:
            # upper bound is exclusive: the site at self.min_max_sites[1] is
            # not included, so an empty range means there is nothing to decode
            self.finished = True

    def obj_left(self):
        # A variant sits at a single point, so left and right coincide
        return self._obj.site.position

    def obj_right(self):
        return self._obj.site.position

    def seek_to_start(self):
        self._obj.decode(self.min_max_sites[0])

    def seek_to_end(self):
        self._obj.decode(self.min_max_sites[1] - 1)

    def __len__(self):
        """
        The number of variants (i.e. sites) over which a newly created iterator will
        iterate.
        """
        return self.min_max_sites[1] - self.min_max_sites[0]


class SimpleContainerSequence:
Expand Down Expand Up @@ -4077,7 +4184,7 @@ def aslist(self, **kwargs):
:return: A list of the trees in this tree sequence.
:rtype: list
"""
return [tree.copy() for tree in self.trees(**kwargs)]
return [tree for tree in self.trees(copy=True, **kwargs)]

@classmethod
def load(cls, file_or_path, *, skip_tables=False, skip_reference_sequence=False):
Expand Down Expand Up @@ -4970,6 +5077,9 @@ def trees(
sample_lists=False,
root_threshold=1,
sample_counts=None,
left=None,
right=None,
copy=None,
tracked_leaves=None,
leaf_counts=None,
leaf_lists=None,
Expand Down Expand Up @@ -5001,28 +5111,39 @@ def trees(
are roots. To efficiently restrict the roots of the tree to
those subtending meaningful topology, set this to 2. This value
is only relevant when trees have multiple roots.
:param float left: The left-most coordinate of the region over which
to iterate. Default: ``None`` treated as 0.
:param float right: The right-most coordinate of the region over which
to iterate. Default: ``None`` treated as ``.sequence_length``. This
value is exclusive, so that a tree whose ``interval.left`` is exactly
equivalent to ``right`` will not be included in the iteration.
:param bool copy: Return an immutable copy of each tree. This will be
inefficient. Default: ``None`` treated as False.
:param bool sample_counts: Deprecated since 0.2.4.
:return: An iterator over the Trees in this tree sequence.
:rtype: collections.abc.Iterable, :class:`Tree`
:rtype: TreeIterator
"""
# tracked_leaves, leaf_counts and leaf_lists are deprecated aliases
# for tracked_samples, sample_counts and sample_lists respectively.
# These are left over from an older version of the API when leaves
# and samples were synonymous.
interval = self._check_genomic_range(left, right)
if tracked_leaves is not None:
tracked_samples = tracked_leaves
if leaf_counts is not None:
sample_counts = leaf_counts
if leaf_lists is not None:
sample_lists = leaf_lists
if copy is None:
copy = False
tree = Tree(
self,
tracked_samples=tracked_samples,
sample_lists=sample_lists,
root_threshold=root_threshold,
sample_counts=sample_counts,
)
return TreeIterator(tree)
return TreeIterator(tree, interval=interval, return_copies=copy)

def coiterate(self, other, **kwargs):
"""
Expand Down Expand Up @@ -5309,8 +5430,8 @@ def variants(
:param int right: End with the last site before this position. If ``None``
(default) assume ``right`` is the sequence length, so that the last
variant corresponds to the last site in the tree sequence.
:return: An iterator over all variants in this tree sequence.
:rtype: iter(:class:`Variant`)
:return: An iterator over the specified variants in this tree sequence.
:rtype: VariantIterator
"""
interval = self._check_genomic_range(left, right)
if impute_missing_data is not None:
Expand All @@ -5327,26 +5448,13 @@ def variants(
copy = True
# See comments for the Variant type for discussion on why the
# present form was chosen.
variant = tskit.Variant(
variant_object = tskit.Variant(
self,
samples=samples,
isolated_as_missing=isolated_as_missing,
alleles=alleles,
)
if left == 0 and right == self.sequence_length:
start = 0
stop = self.num_sites
else:
start, stop = np.searchsorted(self.sites_position, interval)

if copy:
for site_id in range(start, stop):
variant.decode(site_id)
yield variant.copy()
else:
for site_id in range(start, stop):
variant.decode(site_id)
yield variant
return VariantIterator(variant_object, interval=interval, copy=copy)

def genotype_matrix(
self,
Expand Down