From 603b8fa03d6ac29d14bd0bbb6594e4c8a3142692 Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Wed, 3 May 2023 17:27:01 +0100 Subject: [PATCH] Generic iterator class for trees and variants Also allows left and right to be passed to the trees iterator. --- python/tests/test_genotypes.py | 12 ++- python/tests/test_highlevel.py | 56 ++++++++++- python/tskit/genotypes.py | 30 ++++++ python/tskit/trees.py | 178 ++++++++++++++++++++++++++------- 4 files changed, 236 insertions(+), 40 deletions(-) diff --git a/python/tests/test_genotypes.py b/python/tests/test_genotypes.py index 329867b600..dc50385f29 100644 --- a/python/tests/test_genotypes.py +++ b/python/tests/test_genotypes.py @@ -661,7 +661,9 @@ def test_simple_case(self, ts_fixture): ts = ts_fixture test_variant = tskit.Variant(ts) test_variant.decode(1) - for v in ts.variants(left=ts.site(1).position, right=ts.site(2).position): + v_iter = ts.variants(left=ts.site(1).position, right=ts.site(2).position) + assert len(v_iter) == 1 + for v in v_iter: # should only decode the first variant assert v.site.id == 1 assert np.all(v.genotypes == test_variant.genotypes) @@ -686,7 +688,9 @@ def test_left(self, left, expected): for x in range(int(tables.sequence_length)): tables.sites.add_row(position=x, ancestral_state="A") ts = tables.tree_sequence() - positions = [var.site.position for var in ts.variants(left=left)] + v_iter = ts.variants(left=left) + assert len(v_iter) == len(expected) + positions = [var.site.position for var in v_iter] assert positions == expected @pytest.mark.parametrize( @@ -706,7 +710,9 @@ def test_right(self, right, expected): for x in range(int(tables.sequence_length)): tables.sites.add_row(position=x, ancestral_state="A") ts = tables.tree_sequence() - positions = [var.site.position for var in ts.variants(right=right)] + v_iter = ts.variants(right=right) + assert len(v_iter) == len(expected) + positions = [var.site.position for var in v_iter] assert positions == expected @pytest.mark.parametrize("bad_left", [-1, 10, 100, np.nan, np.inf, -np.inf]) diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index ca67e24ce0..be433ca8ea 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -1724,6 +1724,48 @@ def test_trees_interface(self): assert t.get_num_tracked_samples(0) == 0 assert list(t.samples(0)) == [0] + def test_trees_bad_left_right(self): + ts = tskit.Tree.generate_balanced(10, span=1).tree_sequence + with pytest.raises(ValueError): + ts.trees(left=0.5, right=0.5) + with pytest.raises(ValueError): + ts.trees(left=0.5, right=0.4) + with pytest.raises(ValueError): + ts.trees(left=0.5, right=1.1) + with pytest.raises(ValueError): + ts.trees(left=-0.1, right=0.1) + with pytest.raises(ValueError): + ts.trees(left=1, right=1.5) + + def test_trees_left_right_one_tree(self): + ts = tskit.Tree.generate_balanced(10).tree_sequence + tree_iterator = ts.trees(left=0.5, right=0.6) + assert len(tree_iterator) == 1 + trees = [tree.copy() for tree in tree_iterator] + assert len(trees) == 1 + tree_iterator = reversed(ts.trees(left=0.5, right=0.6)) + assert len(tree_iterator) == 1 + assert trees[0] == ts.first() + trees = [tree.copy() for tree in tree_iterator] + assert len(trees) == 1 + assert trees[0] == ts.first() + + @pytest.mark.parametrize( + "interval", [(0, 0.5), (0.4, 0.6), (0.5, np.nextafter(0.5, 1)), (0.5, 1)] + ) + def test_trees_left_right_many_trees(self, interval): + ts = msprime.simulate(5, recombination_rate=10, random_seed=1) + assert ts.num_trees > 10 + tree_iter = ts.trees(left=interval[0], right=interval[1]) + expected_length = len(tree_iter) + n_trees = 0 + for tree in ts.trees(): + # check if the tree is within the interval + if tree.interval[1] > interval[0] and tree.interval[0] < interval[1]: + n_trees += 1 + assert tree.interval == next(tree_iter).interval + assert n_trees == expected_length + @pytest.mark.parametrize("ts", get_example_tree_sequences()) def test_get_pairwise_diversity(self, ts): with pytest.raises(ValueError, match="at least one element"): @@ -2994,8 +3036,18 @@ def test_trees_params(self): ) # Skip the first param, which is `tree_sequence` and `self` respectively tree_class_params = tree_class_params[1:] - # The trees iterator has some extra (deprecated) aliases - trees_iter_params = trees_iter_params[1:-3] + # The trees iterator has some extra (deprecated) aliases at the end + num_deprecated = 3 + trees_iter_params = trees_iter_params[1:-num_deprecated] + + # The trees iterator also has left/right/copy params which aren't in __init__() + assert trees_iter_params[-1][0] == "copy" + trees_iter_params = trees_iter_params[:-1] + assert trees_iter_params[-1][0] == "right" + trees_iter_params = trees_iter_params[:-1] + assert trees_iter_params[-1][0] == "left" + trees_iter_params = trees_iter_params[:-1] + assert trees_iter_params == tree_class_params diff --git a/python/tskit/genotypes.py b/python/tskit/genotypes.py index 239e135777..e729ebf296 100644 --- a/python/tskit/genotypes.py +++ b/python/tskit/genotypes.py @@ -233,6 +233,36 @@ def decode(self, site_id) -> None: """ self._ll_variant.decode(site_id) + def next(self): # noqa A002 + """ + Decode the variant at the next site, returning True if successful, False + if the variant is already at the last site. If the variant has not yet been + decoded, decode the variant at the first site. + """ + if self._ll_variant.site_id == self.tree_sequence.num_sites - 1: + # TODO: should also set the variant to the null state + return False + if self._ll_variant.site_id == tskit.NULL: + self.decode(0) + else: + self.decode(self._ll_variant.site_id + 1) + return True + + def prev(self): + """ + Decode the variant at the previous site, returning True if successful, False + if the variant is already at the first site. If the variant has not yet been + decoded at any site, decode the variant at the last site. + """ + if self._ll_variant.site_id == 0: + # TODO: should also set the variant to the null state + return False + if self._ll_variant.site_id == tskit.NULL: + self.decode(self.tree_sequence.num_sites - 1) + else: + self.decode(self._ll_variant.site_id - 1) + return True + def copy(self) -> Variant: """ Create a copy of this Variant. Note that calling :meth:`decode` on the diff --git a/python/tskit/trees.py b/python/tskit/trees.py index cb75f711fb..23d462dce3 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -3863,15 +3863,26 @@ def load_text( return tc.tree_sequence() -class TreeIterator: - """ - Simple class providing forward and backward iteration over a tree sequence. - """ - - def __init__(self, tree): - self.tree = tree - self.more_trees = True +class ObjectIterator: + # Simple class providing forward and backward iteration over a + # mutable object with ``next()`` and ``prev()`` methods, e.g. + # a Tree or a Variant. ``interval`` allows the bounds of the + # iterator to be specified, and should already have + # been checked using _check_genomic_range(left, right) + # If ``return_copies`` is True, the iterator will return + # immutable copies of each object (this is likely to be significantly + # less efficient). + # It can be useful to define __len__ on one of these iterators, + # which e.g. allows progress bars to provide useful feedback. + + def __init__(self, obj, interval, return_copies=False): + self._obj = obj + self.min_pos = interval[0] + self.max_pos = interval[1] + self.return_copies = return_copies self.forward = True + self.started = False + self.finished = False def __iter__(self): return self @@ -3880,17 +3891,113 @@ def __reversed__(self): self.forward = False return self + def obj_left(self): + # Used to work out where to stop iterating when going backwards. + # Override with code to return the left coordinate of self.obj + raise NotImplementedError() + + def obj_right(self): + # Used to work out when to stop iterating when going forwards. + # Override with code to return the right coordinate of self.obj + raise NotImplementedError() + + def seek_to_start(self): + # Override to set the object position to self.min_pos + raise NotImplementedError() + + def seek_to_end(self): + # Override to set the object position just before self.max_pos + raise NotImplementedError() + def __next__(self): - if self.forward: - self.more_trees = self.more_trees and self.tree.next() - else: - self.more_trees = self.more_trees and self.tree.prev() - if not self.more_trees: + if not self.finished: + if not self.started: + if self.forward: + self.seek_to_start() + else: + self.seek_to_end() + self.started = True + else: + if self.forward: + if not self._obj.next() or self.obj_left() >= self.max_pos: + self.finished = True + else: + if not self._obj.prev() or self.obj_right() < self.min_pos: + self.finished = True + if self.finished: raise StopIteration() - return self.tree + return self._obj.copy() if self.return_copies else self._obj + + +class TreeIterator(ObjectIterator): + """ + An iterator over some or all of the :class:`trees` + in a :class:`TreeSequence`. + """ + + def obj_left(self): + return self._obj.interval.left + + def obj_right(self): + return self._obj.interval.right + + def seek_to_start(self): + self._obj.seek(self.min_pos) + + def seek_to_end(self): + self._obj.seek(np.nextafter(self.max_pos, -np.inf)) def __len__(self): - return self.tree.tree_sequence.num_trees + """ + The number of trees over which a newly created iterator will iterate. + """ + ts = self._obj.tree_sequence + if self.min_pos == 0 and self.max_pos == ts.sequence_length: + # a common case: don't incur the cost of searching through the breakpoints + return ts.num_trees + breaks = ts.breakpoints(as_array=True) + left_index = breaks.searchsorted(self.min_pos, side="right") + right_index = breaks.searchsorted(self.max_pos, side="left") + return right_index - left_index + 1 + + +class VariantIterator(ObjectIterator): + """ + An iterator over some or all of the :class:`variants` + in a :class:`TreeSequence`. + """ + + def __init__(self, variant, interval, copy): + super().__init__(variant, interval, copy) + if interval[0] == 0 and interval[1] == variant.tree_sequence.sequence_length: + # a common case: don't incur the cost of searching through the positions + self.min_max_sites = [0, variant.tree_sequence.num_sites] + else: + self.min_max_sites = variant.tree_sequence.sites_position.searchsorted( + interval + ) + if self.min_max_sites[0] >= self.min_max_sites[1]: + # upper bound is exclusive: we don't include the site at self.bound[1] + self.finished = True + + def obj_left(self): + return self._obj.site.position + + def obj_right(self): + return self._obj.site.position + + def seek_to_start(self): + self._obj.decode(self.min_max_sites[0]) + + def seek_to_end(self): + self._obj.decode(self.min_max_sites[1] - 1) + + def __len__(self): + """ + The number of variants (i.e. sites) over which a newly created iterator will + iterate. + """ + return self.min_max_sites[1] - self.min_max_sites[0] class SimpleContainerSequence: @@ -4077,7 +4184,7 @@ def aslist(self, **kwargs): :return: A list of the trees in this tree sequence. :rtype: list """ - return [tree.copy() for tree in self.trees(**kwargs)] + return [tree for tree in self.trees(copy=True, **kwargs)] @classmethod def load(cls, file_or_path, *, skip_tables=False, skip_reference_sequence=False): @@ -4970,6 +5077,9 @@ def trees( sample_lists=False, root_threshold=1, sample_counts=None, + left=None, + right=None, + copy=None, tracked_leaves=None, leaf_counts=None, leaf_lists=None, @@ -5001,20 +5111,31 @@ def trees( are roots. To efficiently restrict the roots of the tree to those subtending meaningful topology, set this to 2. This value is only relevant when trees have multiple roots. + :param float left: The left-most coordinate of the region over which + to iterate. Default: ``None`` treated as 0. + :param float right: The right-most coordinate of the region over which + to iterate. Default: ``None`` treated as ``.sequence_length``. This + value is exclusive, so that a tree whose ``interval.left`` is exactly + equivalent to ``right`` will not be included in the iteration. + :param bool copy: Return a immutable copy of each tree. This will be + inefficient. Default: ``None`` treated as False. :param bool sample_counts: Deprecated since 0.2.4. :return: An iterator over the Trees in this tree sequence. - :rtype: collections.abc.Iterable, :class:`Tree` + :rtype: TreeIterator """ # tracked_leaves, leaf_counts and leaf_lists are deprecated aliases # for tracked_samples, sample_counts and sample_lists respectively. # These are left over from an older version of the API when leaves # and samples were synonymous. + interval = self._check_genomic_range(left, right) if tracked_leaves is not None: tracked_samples = tracked_leaves if leaf_counts is not None: sample_counts = leaf_counts if leaf_lists is not None: sample_lists = leaf_lists + if copy is None: + copy = False tree = Tree( self, tracked_samples=tracked_samples, @@ -5022,7 +5143,7 @@ def trees( root_threshold=root_threshold, sample_counts=sample_counts, ) - return TreeIterator(tree) + return TreeIterator(tree, interval=interval, return_copies=copy) def coiterate(self, other, **kwargs): """ @@ -5309,8 +5430,8 @@ def variants( :param int right: End with the last site before this position. If ``None`` (default) assume ``right`` is the sequence length, so that the last variant corresponds to the last site in the tree sequence. - :return: An iterator over all variants in this tree sequence. - :rtype: iter(:class:`Variant`) + :return: An iterator over the specified variants in this tree sequence. + :rtype: VariantIterator """ interval = self._check_genomic_range(left, right) if impute_missing_data is not None: @@ -5327,26 +5448,13 @@ def variants( copy = True # See comments for the Variant type for discussion on why the # present form was chosen. - variant = tskit.Variant( + variant_object = tskit.Variant( self, samples=samples, isolated_as_missing=isolated_as_missing, alleles=alleles, ) - if left == 0 and right == self.sequence_length: - start = 0 - stop = self.num_sites - else: - start, stop = np.searchsorted(self.sites_position, interval) - - if copy: - for site_id in range(start, stop): - variant.decode(site_id) - yield variant.copy() - else: - for site_id in range(start, stop): - variant.decode(site_id) - yield variant + return VariantIterator(variant_object, interval=interval, copy=copy) def genotype_matrix( self,