From ec992daab6ac880a3e6e2bb31e2dddfa463adbd5 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Tue, 29 Nov 2022 23:16:30 +0000 Subject: [PATCH 01/84] Fix doc build --- docs/_config.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/_config.yml b/docs/_config.yml index cba8b909e9..600b1705a3 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -38,6 +38,7 @@ sphinx: - sphinx.ext.intersphinx - sphinx_issues - sphinxarg.ext + - IPython.sphinxext.ipython_console_highlighting #- sphinxcontrib.prettyspecialmethods config: From d3ae44fb33ee9c820dca215d9b9f09d02207888f Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Mon, 28 Nov 2022 15:15:01 +0000 Subject: [PATCH 02/84] Add max_num_trees --- python/tests/data/svg/ts_max_trees.svg | 455 ++++++++++++++++ .../tests/data/svg/ts_max_trees_treewise.svg | 429 +++++++++++++++ python/tests/test_drawing.py | 88 +++- python/tskit/drawing.py | 492 ++++++++++++------ python/tskit/trees.py | 7 + 5 files changed, 1309 insertions(+), 162 deletions(-) create mode 100644 python/tests/data/svg/ts_max_trees.svg create mode 100644 python/tests/data/svg/ts_max_trees_treewise.svg diff --git a/python/tests/data/svg/ts_max_trees.svg b/python/tests/data/svg/ts_max_trees.svg new file mode 100644 index 0000000000..fff638c00e --- /dev/null +++ b/python/tests/data/svg/ts_max_trees.svg @@ -0,0 +1,455 @@ + + + + + + + + + + + + + + + + + + Genome position + + + + + + + + + 15 + + + + + + 16 + + + + + + 20 + + + + + + 93 + + + + + + 98 + + + + + + + + + + + + + + + + + + + + + + + + + Time (generations) + + + + + + + 0.00 + + + + + + 0.25 + + + + + + 0.32 + + + + + + 0.56 + + + + + + 0.57 + + + + + + 1.63 + + + + + + 2.32 + + + + + + 3.06 + + + + + + 4.15 + + + + + + + + + + + + + + 1 + + + + + + 4 + + + + + 5 + + + + 7 + + + + 11 + + + + + + 0 + + + + + + 2 + + + + + 3 + + + + 8 + + + + 16 + + + 33 + + + + + + + + + + + 1 + + + + + + 4 + + + + + 5 + + + + 7 + + + + 11 + + + + + + 0 + + + + + + 2 + + + + + 3 + + + + 8 + + + + 16 + + + 25 + + + + 
+ + + + + + + 1 + + + + + + 4 + + + + + 5 + + + + 7 + + + + 11 + + + + + + 0 + + + + + + 2 + + + + + 3 + + + + + + 1 + + + 8 + + + + 16 + + + 30 + + + + + + + 31 trees + + + skipped + + + + + + + + + + + 0 + + + + + + 5 + + + + + + 1 + + + + + + + 8 + + + 4 + + + + 6 + + + + 7 + + + + 10 + + + + + + 2 + + + + + 3 + + + + 15 + + + 42 + + + + + + + + + + + 0 + + + + + + 5 + + + + + + 1 + + + + + 4 + + + + 6 + + + + 7 + + + + + + 9 + + + 10 + + + + + + 2 + + + + + 3 + + + + 15 + + + 39 + + + + + + diff --git a/python/tests/data/svg/ts_max_trees_treewise.svg b/python/tests/data/svg/ts_max_trees_treewise.svg new file mode 100644 index 0000000000..84b1929b02 --- /dev/null +++ b/python/tests/data/svg/ts_max_trees_treewise.svg @@ -0,0 +1,429 @@ + + + + + + + + + + Genome position + + + + + + + + + 15 + + + + + + 16 + + + + + + 20 + + + + + + 93 + + + + + + 98 + + + + + + + Time (generations) + + + + + + + 0.00 + + + + + + 0.25 + + + + + + 0.32 + + + + + + 0.56 + + + + + + 0.57 + + + + + + 1.63 + + + + + + 2.32 + + + + + + 3.06 + + + + + + 4.15 + + + + + + + + + + + + + + 1 + + + + + + 4 + + + + + 5 + + + + 7 + + + + 11 + + + + + + 0 + + + + + + 2 + + + + + 3 + + + + 8 + + + + 16 + + + 33 + + + + + + + + + + + 1 + + + + + + 4 + + + + + 5 + + + + 7 + + + + 11 + + + + + + 0 + + + + + + 2 + + + + + 3 + + + + 8 + + + + 16 + + + 25 + + + + + + + + + + + 1 + + + + + + 4 + + + + + 5 + + + + 7 + + + + 11 + + + + + + 0 + + + + + + 2 + + + + + 3 + + + + + + 1 + + + 8 + + + + 16 + + + 30 + + + + + + + 31 trees + + + skipped + + + + + + + + + + + 0 + + + + + + 5 + + + + + + 1 + + + + + + + 8 + + + 4 + + + + 6 + + + + 7 + + + + 10 + + + + + + 2 + + + + + 3 + + + + 15 + + + 42 + + + + + + + + + + + 0 + + + + + + 5 + + + + + + 1 + + + + + 4 + + + + 6 + + + + 7 + + + + + + 9 + + + 10 + + + + + + 2 + + + + + 3 + + + + 15 + + + 39 + + + + + + diff --git a/python/tests/test_drawing.py b/python/tests/test_drawing.py index bece3bc11f..97dbe77bd5 100644 --- a/python/tests/test_drawing.py +++ 
b/python/tests/test_drawing.py @@ -1453,9 +1453,9 @@ def test_no_repr_svg(self): output._repr_svg_() -class TestDrawSvg(TestTreeDraw, xmlunittest.XmlTestMixin): +class TestDrawSvgBase(TestTreeDraw, xmlunittest.XmlTestMixin): """ - Tests the SVG tree drawing. + Base class for testing the SVG tree drawing method """ def verify_basic_svg(self, svg, width=200, height=200, num_trees=1): @@ -1496,6 +1496,12 @@ def verify_basic_svg(self, svg, width=200, height=200, num_trees=1): cls = group.attrib["class"] assert re.search(r"\broot\b", cls) + +class TestDrawSvg(TestDrawSvgBase): + """ + Simple testing for the draw_svg method + """ + def test_repr_svg(self): ts = self.get_simple_ts() svg = ts.draw_svg() @@ -1535,7 +1541,9 @@ def test_draw_to_file(self, tmp_path): def test_nonimplemented_base_class(self): ts = self.get_simple_ts() - plot = drawing.SvgPlot(ts, (100, 100), {}, "", "dummy-class", None, True, True) + plot = drawing.SvgAxisPlot( + ts, (100, 100), {}, "", "dummy-class", None, True, True + ) plot.set_spacing() with pytest.raises(NotImplementedError): plot.draw_x_axis(tick_positions=ts.breakpoints(as_array=True)) @@ -2422,6 +2430,43 @@ def test_debug_box(self): assert svg.count("outer_plotbox") == ts.num_trees + 1 assert svg.count("inner_plotbox") == ts.num_trees + 1 + @pytest.mark.parametrize("max_trees", [-1, 0, 1]) + def test_bad_max_num_trees(self, max_trees): + ts = self.get_simple_ts() + with pytest.raises(ValueError, match="at least 2"): + ts.draw_svg(max_num_trees=max_trees) + + @pytest.mark.parametrize("max_trees", [2, 4, 9]) + def test_max_num_trees(self, max_trees): + ts = msprime.sim_ancestry( + 3, sequence_length=100, recombination_rate=0.1, random_seed=1 + ) + ts = msprime.sim_mutations(ts, rate=0.1, random_seed=1) + assert ts.num_trees > 10 + num_sites = 0 + num_unplotted_sites = 0 + svg = ts.draw_svg(max_num_trees=max_trees) + for tree in ts.trees(): + if ( + tree.index < (max_trees + 1) // 2 + or ts.num_trees - tree.index <= max_trees // 2 + ): + 
num_sites += tree.num_sites + assert re.search(rf"t{tree.index}[^\d]", svg) is not None + else: + assert re.search(rf"t{tree.index}[^\d]", svg) is None + num_unplotted_sites += tree.num_sites + assert num_unplotted_sites > 0 + site_strings_in_stylesheet = svg.count(".site") + assert svg.count("site") - site_strings_in_stylesheet == num_sites + self.verify_basic_svg(svg, width=200 * (max_trees + 1)) + + +class TestDrawKnownSvg(TestDrawSvgBase): + """ + Compare against known files + """ + def verify_known_svg(self, svg, filename, save=False, **kwargs): # expected SVG files can be inspected in tests/data/svg/*.svg svg = xml.dom.minidom.parseString( @@ -2752,6 +2797,43 @@ def test_known_svg_ts_xlim(self, overwrite_viz, draw_plotbox, caplog): num_trees = sum(1 for b in ts.breakpoints() if 0.051 <= b < 0.9) + 1 self.verify_known_svg(svg, "ts_x_lim.svg", overwrite_viz, width=200 * num_trees) + def test_known_max_num_trees(self, overwrite_viz, draw_plotbox, caplog): + max_trees = 5 + ts = msprime.sim_ancestry( + 3, sequence_length=100, recombination_rate=0.1, random_seed=1 + ) + ts = msprime.sim_mutations(ts, rate=0.01, random_seed=1) + assert ts.num_trees > 10 + first_break = next(ts.trees()).interval.right + # limit to just past the first tree + svg = ts.draw_svg( + max_num_trees=max_trees, + x_lim=(first_break + 0.1, ts.sequence_length - 0.1), + y_axis=True, + time_scale="log_time", + ) + self.verify_known_svg( + svg, "ts_max_trees.svg", overwrite_viz, width=200 * (max_trees + 1) + ) + + def test_known_max_num_trees_treewise(self, overwrite_viz, draw_plotbox, caplog): + max_trees = 5 + ts = msprime.sim_ancestry( + 3, sequence_length=100, recombination_rate=0.1, random_seed=1 + ) + ts = msprime.sim_mutations(ts, rate=0.01, random_seed=1) + assert ts.num_trees > 10 + first_break = next(ts.trees()).interval.right + svg = ts.draw_svg( + max_num_trees=max_trees, + x_lim=(first_break + 0.1, ts.sequence_length - 0.1), + y_axis=True, + x_scale="treewise", + ) + 
self.verify_known_svg( + svg, "ts_max_trees_treewise.svg", overwrite_viz, width=200 * (max_trees + 1) + ) + class TestRounding: def test_rnd(self): diff --git a/python/tskit/drawing.py b/python/tskit/drawing.py index 4097172632..3c6ad5beae 100644 --- a/python/tskit/drawing.py +++ b/python/tskit/drawing.py @@ -50,8 +50,9 @@ # constants for whether to plot a tree in a tree sequence OMIT = 1 -LEFT_CLIPPED_BIT = 2 -RIGHT_CLIPPED_BIT = 4 +LEFT_CLIP = 2 +RIGHT_CLIP = 4 +OMIT_MIDDLE = 8 @dataclass @@ -62,7 +63,7 @@ class Offsets: mutation: int = 0 -@dataclass +@dataclass(frozen=True) class Timescaling: "Class used to transform the time axis" max_time: float @@ -77,9 +78,9 @@ def __post_init__(self): if self.use_log_transform: if self.min_time < 0: raise ValueError("Cannot use a log scale if there are negative times") - self.transform = self.log_transform + super().__setattr__("transform", self.log_transform) else: - self.transform = self.linear_transform + super().__setattr__("transform", self.linear_transform) def log_transform(self, y): "Standard log transform but allowing for values of 0 by adding 1" @@ -229,10 +230,11 @@ def create_tick_labels(tick_values, decimal_places=2): return [f"{lab:.{label_precision}f}" for lab in tick_values] -def clip_ts(ts, x_min, x_max): +def clip_ts(ts, x_min, x_max, max_num_trees=None): """ Culls the edges of the tree sequence outside the limits of x_min and x_max if - necessary. + necessary, and flags internal trees for omission if there are more than + max_num_trees in the tree sequence Returns the new tree sequence using the same genomic scale, and an array specifying which trees to actually plot from it. 
This array contains @@ -276,6 +278,12 @@ def clip_ts(ts, x_min, x_max): if ts.num_sites > 0 and np.max(sites.position) > x_max: x_max = ts.sequence_length # Last region has sites but no edges => keep + if max_num_trees is None: + max_num_trees = np.inf + + if max_num_trees < 2: + raise ValueError("Must show at least 2 trees when clipping a tree sequence") + if (x_min > 0) or (x_max < ts.sequence_length): old_breaks = ts.breakpoints(as_array=True) offsets.tree = np.searchsorted(old_breaks, x_min, "right") - 2 @@ -303,10 +311,22 @@ def clip_ts(ts, x_min, x_max): # Which breakpoints are new ones, as a result of clipping new_breaks = np.logical_not(np.isin(ts.breakpoints(as_array=True), old_breaks)) - tree_status[new_breaks[:-1]] |= LEFT_CLIPPED_BIT - tree_status[new_breaks[1:]] |= RIGHT_CLIPPED_BIT + tree_status[new_breaks[:-1]] |= LEFT_CLIP + tree_status[new_breaks[1:]] |= RIGHT_CLIP else: tree_status = np.zeros(ts.num_trees, dtype=np.uint8) + + first_tree = 1 if tree_status[0] & OMIT else 0 + last_tree = ts.num_trees - 2 if tree_status[-1] & OMIT else ts.num_trees - 1 + num_shown_trees = last_tree - first_tree + 1 + if num_shown_trees > max_num_trees: + num_start_trees = max_num_trees // 2 + (1 if max_num_trees % 2 else 0) + num_end_trees = max_num_trees // 2 + assert num_start_trees + num_end_trees == max_num_trees + tree_status[ + (first_tree + num_start_trees) : (last_tree - num_end_trees + 1) + ] = (OMIT | OMIT_MIDDLE) + return ts, tree_status, offsets @@ -336,20 +356,27 @@ def rnd(x): return x -def referenced_nodes(ts): +def edge_and_sample_nodes(ts, omit_regions=None): """ - Return the ids of nodes which are actually plotted in this tree sequence - (i.e. do not include nodes which are not samples and not in any edge: this - happens extensively in plotting tree sequences with x_lim specified) + Return ids of nodes which are mentioned in an edge in this tree sequence or which + are samples: nodes not connected to an edge are often found if x_lim is specified. 
""" - ids = np.concatenate( - ( - ts.tables.edges.child, - ts.tables.edges.parent, - np.where(ts.tables.nodes.flags & NODE_IS_SAMPLE)[0], - ) + if omit_regions is None or len(omit_regions) == 0: + ids = np.concatenate((ts.edges_child, ts.edges_parent)) + else: + ids = np.array([], dtype=ts.edges_child.dtype) + edges = ts.tables.edges + assert omit_regions.shape[1] == 2 + omit_regions = omit_regions.flatten() + assert np.all(omit_regions == np.unique(omit_regions)) # Check they're in order + use_regions = np.concatenate(([0.0], omit_regions, [ts.sequence_length])) + use_regions = use_regions.reshape(-1, 2) + for left, right in use_regions: + used_edges = edges[np.logical_and(edges.left >= left, edges.right < right)] + ids = np.concatenate((ids, used_edges.child, used_edges.parent)) + return np.unique( + np.concatenate((ids, np.where(ts.nodes_flags & NODE_IS_SAMPLE)[0])) ) - return np.unique(ids) def draw_tree( @@ -469,10 +496,17 @@ def add_class(attrs_dict, classes_str): @dataclass class Plotbox: total_size: list - pad_top: float - pad_left: float - pad_bottom: float - pad_right: float + pad_top: float = 0 + pad_left: float = 0 + pad_bottom: float = 0 + pad_right: float = 0 + + def set_padding(self, top, left, bottom, right): + self.pad_top = top + self.pad_left = left + self.pad_bottom = bottom + self.pad_right = right + self._check() @property def max_x(self): @@ -507,6 +541,9 @@ def height(self): return self.bottom - self.top def __post_init__(self): + self._check() + + def _check(self): if self.width < 1 or self.height < 1: raise ValueError("Image size too small to fit") @@ -537,7 +574,92 @@ def draw(self, dwg, add_to, colour="grey"): class SvgPlot: - """The base class for plotting either a tree or a tree sequence as an SVG file""" + """ + The base class for plotting any box to canvas + """ + + text_height = 14 # May want to calculate this based on a font size + line_height = text_height * 1.2 # allowing padding above and below a line + + def __init__( + self, + 
size, + svg_class, + root_svg_attributes=None, + canvas_size=None, + ): + """ + Creates self.drawing, an svgwrite.Drawing object for further use, and populates + it with a base group. The root_groups will be populated with + items that can be accessed from the outside, such as the plotbox, axes, etc. + """ + + if root_svg_attributes is None: + root_svg_attributes = {} + if canvas_size is None: + canvas_size = size + dwg = svgwrite.Drawing(size=canvas_size, debug=True, **root_svg_attributes) + + self.image_size = size + self.plotbox = Plotbox(size) + self.root_groups = {} + self.svg_class = svg_class + self.timescaling = None + self.root_svg_attributes = root_svg_attributes + self.dwg_base = dwg.add(dwg.g(class_=svg_class)) + self.drawing = dwg + + def get_plotbox(self): + """ + Get the svgwrite plotbox, creating it if necessary. + """ + if "plotbox" not in self.root_groups: + dwg = self.drawing + self.root_groups["plotbox"] = self.dwg_base.add(dwg.g(class_="plotbox")) + return self.root_groups["plotbox"] + + def add_text_in_group(self, text, add_to, pos, group_class=None, **kwargs): + """ + Add the text to the elem within a group; allows text rotations to work smoothly, + otherwise, if x & y parameters are used to position text, rotations applied to + the text tag occur around the (0,0) point of the containing group + """ + dwg = self.drawing + group_attributes = {"transform": f"translate({rnd(pos[0])} {rnd(pos[1])})"} + if group_class is not None: + group_attributes["class_"] = group_class + grp = add_to.add(dwg.g(**group_attributes)) + grp.add(dwg.text(text, **kwargs)) + + +class SvgSkippedPlot(SvgPlot): + def __init__( + self, + size, + num_skipped, + ): + super().__init__( + size, + svg_class="skipped", + ) + container = self.get_plotbox() + x = self.plotbox.width / 2 + y = self.plotbox.height / 2 + self.add_text_in_group( + f"{num_skipped} trees", + container, + (x, y - self.line_height / 2), + text_anchor="middle", + ) + self.add_text_in_group( + "skipped", 
container, (x, y + self.line_height / 2), text_anchor="middle" + ) + + +class SvgAxisPlot(SvgPlot): + """ + The class used for plotting either a tree or a tree sequence as an SVG file + """ standard_style = ( ".background path {fill: #808080; fill-opacity: 0}" @@ -546,6 +668,7 @@ class SvgPlot: ".x-axis .tick .lab {font-weight: bold; dominant-baseline: hanging}" ".axes, .tree {font-size: 14px; text-anchor: middle}" ".axes line, .edge {stroke: black; fill: none}" + ".axes .ax-skip {stroke-dasharray: 4}" ".y-axis .grid {stroke: #FAFAFA}" ".node > .sym {fill: black; stroke: none}" ".site > .sym {stroke: black}" @@ -561,8 +684,6 @@ class SvgPlot: ) # TODO: we may want to make some of the constants below into parameters - text_height = 14 # May want to calculate this based on a font size - line_height = text_height * 1.2 # allowing padding above and below a line root_branch_fraction = 1 / 8 # Rel root branch len, unless it has a timed mutation default_tick_length = 5 default_tick_length_site = 10 @@ -587,27 +708,18 @@ def __init__( omit_sites=None, canvas_size=None, ): - """ - Creates self.drawing, an svgwrite.Drawing object for further use, and populates - it with a stylesheet and base group. The root_groups will be populated with - items that can be accessed from the outside, such as the plotbox, axes, etc. 
- """ + super().__init__( + size, + svg_class, + root_svg_attributes, + canvas_size, + ) self.ts = ts - self.image_size = size - self.svg_class = svg_class - if root_svg_attributes is None: - root_svg_attributes = {} - if canvas_size is None: - canvas_size = size - self.root_svg_attributes = root_svg_attributes - dwg = svgwrite.Drawing(size=canvas_size, debug=True, **root_svg_attributes) + dwg = self.drawing # Put all styles in a single stylesheet (required for Inkscape 0.92) style = self.standard_style + ("" if style is None else style) dwg.defs.add(dwg.style(style)) - self.dwg_base = dwg.add(dwg.g(class_=svg_class)) - self.root_groups = {} self.debug_box = debug_box - self.drawing = dwg self.time_scale = check_time_scale(time_scale) self.y_axis = y_axis self.x_axis = x_axis @@ -626,29 +738,6 @@ def __init__( self.omit_sites = omit_sites self.mutations_outside_tree = set() # mutations in here get an additional class - def get_plotbox(self): - """ - Get the svgwrite plotbox (contains the tree(s) but not axes etc), creating it - if necessary. 
- """ - if "plotbox" not in self.root_groups: - dwg = self.drawing - self.root_groups["plotbox"] = self.dwg_base.add(dwg.g(class_="plotbox")) - return self.root_groups["plotbox"] - - def add_text_in_group(self, text, add_to, pos, group_class=None, **kwargs): - """ - Add the text to the elem within a group; allows text rotations to work smoothly, - otherwise, if x & y parameters are used to position text, rotations applied to - the text tag occur around the (0,0) point of the containing group - """ - dwg = self.drawing - group_attributes = {"transform": f"translate({rnd(pos[0])} {rnd(pos[1])})"} - if group_class is not None: - group_attributes["class_"] = group_class - grp = add_to.add(dwg.g(**group_attributes)) - grp.add(dwg.text(text, **kwargs)) - def set_spacing(self, top=0, left=0, bottom=0, right=0): """ Set edges, but allow space for axes etc @@ -663,7 +752,7 @@ def set_spacing(self, top=0, left=0, bottom=0, right=0): bottom += self.x_axis_offset if self.y_axis: left = self.y_axis_offset # Override user-provided, so y-axis is at x=0 - self.plotbox = Plotbox(self.image_size, top, left, bottom, right) + self.plotbox.set_padding(top, left, bottom, right) if self.debug_box: self.root_groups["debug"] = self.dwg_base.add( self.drawing.g(class_="debug") @@ -682,9 +771,12 @@ def draw_x_axis( tick_length_lower=default_tick_length, tick_length_upper=None, # If None, use the same as tick_length_lower site_muts=None, # A dict of site id => mutation to plot as ticks on the x axis + alternate_dash_positions=None, # Where to alternate the axis from solid to dash ): if not self.x_axis and not self.x_label: return + if alternate_dash_positions is None: + alternate_dash_positions = np.array([]) dwg = self.drawing axes = self.get_axes() x_axis = axes.add(dwg.g(class_="x-axis")) @@ -702,7 +794,21 @@ def draw_x_axis( if tick_length_upper is None: tick_length_upper = tick_length_lower y = rnd(self.plotbox.max_y - self.x_axis_offset) - x_axis.add(dwg.line((self.plotbox.left, y), 
(self.plotbox.right, y))) + dash_locs = np.concatenate( + ( + [self.plotbox.left], + self.x_transform(alternate_dash_positions), + [self.plotbox.right], + ) + ) + for i, (x1, x2) in enumerate(zip(dash_locs[:-1], dash_locs[1:])): + x_axis.add( + dwg.line( + (rnd(x1), y), + (rnd(x2), y), + class_="ax-skip" if i % 2 else "ax-line", + ) + ) if tick_positions is not None: if tick_labels is None or isinstance(tick_labels, np.ndarray): if tick_labels is None: @@ -790,7 +896,7 @@ def draw_y_axis( transform="translate(11) rotate(-90)", ) if self.y_axis: - y_axis.add(dwg.line((x, rnd(lower)), (x, rnd(upper)))) + y_axis.add(dwg.line((x, rnd(lower)), (x, rnd(upper)), class_="ax-line")) ticks_group = y_axis.add(dwg.g(class_="ticks")) for y, label in ticks.items(): tick = ticks_group.add( @@ -870,7 +976,7 @@ def x_transform(self, x): ) -class SvgTreeSequence(SvgPlot): +class SvgTreeSequence(SvgAxisPlot): """ A class to draw a tree sequence in SVG format. @@ -906,6 +1012,7 @@ def __init__( mutation_label_attrs=None, tree_height_scale=None, max_tree_height=None, + max_num_trees=None, **kwargs, ): if max_time is None and max_tree_height is not None: @@ -923,10 +1030,13 @@ def __init__( FutureWarning, ) x_lim = check_x_lim(x_lim, max_x=ts.sequence_length) - ts, self.tree_status, offsets = clip_ts(ts, x_lim[0], x_lim[1]) - num_trees = int(np.sum((self.tree_status & OMIT) != OMIT)) + ts, self.tree_status, offsets = clip_ts(ts, x_lim[0], x_lim[1], max_num_trees) + + use_tree = self.tree_status & OMIT == 0 + use_skipped = np.append(np.diff(self.tree_status & OMIT_MIDDLE == 0) == 1, 0) + num_plotboxes = np.sum(np.logical_or(use_tree, use_skipped)) if size is None: - size = (200 * num_trees, 200) + size = (200 * int(num_plotboxes), 200) if max_time is None: max_time = "ts" if min_time is None: @@ -954,53 +1064,68 @@ def __init__( if force_root_branch is None: force_root_branch = any( any(tree.parent(mut.node) == NULL for mut in tree.mutations()) - for tree in ts.trees() + for tree, use in 
zip(ts.trees(), use_tree) + if use ) # TODO add general padding arguments following matplotlib's terminology. self.set_spacing(top=0, left=20, bottom=10, right=20) - svg_trees = [ - SvgTree( - tree, - (self.plotbox.width / num_trees, self.plotbox.height), - time_scale=time_scale, - node_labels=node_labels, - mutation_labels=mutation_labels, - order=order, - force_root_branch=force_root_branch, - symbol_size=symbol_size, - max_time=max_time, - min_time=min_time, - node_attrs=node_attrs, - mutation_attrs=mutation_attrs, - edge_attrs=edge_attrs, - node_label_attrs=node_label_attrs, - mutation_label_attrs=mutation_label_attrs, - offsets=offsets, - # Do not plot axes on these subplots - **kwargs, # pass though e.g. debug boxes - ) - for status, tree in zip(self.tree_status, ts.trees()) - if (status & OMIT) != OMIT - ] + subplot_size = (self.plotbox.width / num_plotboxes, self.plotbox.height) + subplots = [] + for tree, use, summary in zip(ts.trees(), use_tree, use_skipped): + if use: + subplots.append( + SvgTree( + tree, + size=subplot_size, + time_scale=time_scale, + node_labels=node_labels, + mutation_labels=mutation_labels, + order=order, + force_root_branch=force_root_branch, + symbol_size=symbol_size, + max_time=max_time, + min_time=min_time, + node_attrs=node_attrs, + mutation_attrs=mutation_attrs, + edge_attrs=edge_attrs, + node_label_attrs=node_label_attrs, + mutation_label_attrs=mutation_label_attrs, + offsets=offsets, + # Do not plot axes on these subplots + **kwargs, # pass though e.g. 
debug boxes + ) + ) + last_used_index = tree.index + elif summary: + subplots.append( + SvgSkippedPlot( + size=subplot_size, num_skipped=tree.index - last_used_index + ) + ) y = self.plotbox.top - self.tree_plotbox = svg_trees[0].plotbox + self.tree_plotbox = subplots[0].plotbox + tree_is_used, breaks, skipbreaks = self.find_used_trees() self.draw_x_axis( x_scale, + tree_is_used, + breaks, + skipbreaks, tick_length_lower=self.default_tick_length, # TODO - parameterize tick_length_upper=self.default_tick_length_site, # TODO - parameterize ) y_low = self.tree_plotbox.bottom if y_axis is not None: - self.timescaling = svg_trees[0].timescaling - for svg_tree in svg_trees: - if self.timescaling != svg_tree.timescaling: - raise ValueError( - "Can't draw a tree sequence Y axis if trees vary in timescale" - ) + tscales = {s.timescaling for s in subplots if s.timescaling} + if len(tscales) > 1: + raise ValueError( + "Can't draw a tree sequence Y axis if trees vary in timescale" + ) + self.timescaling = tscales.pop() y_low = self.timescaling.transform(self.timescaling.min_time) if y_ticks is None: - y_ticks = np.unique(ts.tables.nodes.time[referenced_nodes(ts)]) + used_nodes = edge_and_sample_nodes(ts, breaks[skipbreaks]) + y_ticks = np.unique(ts.nodes_time[used_nodes]) if self.time_scale == "rank": # Ticks labelled by time not rank y_ticks = dict(enumerate(y_ticks)) @@ -1013,78 +1138,128 @@ def __init__( gridlines=y_gridlines, ) - tree_x = self.plotbox.left - trees = self.get_plotbox() # Top-level TS plotbox contains all trees - trees["class"] = trees["class"] + " trees" - for svg_tree in svg_trees: - tree = trees.add( + subplot_x = self.plotbox.left + container = self.get_plotbox() # Top-level TS plotbox contains all trees + container["class"] = container["class"] + " trees" + for subplot in subplots: + svg_subplot = container.add( self.drawing.g( - class_=svg_tree.svg_class, transform=f"translate({rnd(tree_x)} {y})" + class_=subplot.svg_class, + 
transform=f"translate({rnd(subplot_x)} {y})", ) ) - for svg_items in svg_tree.root_groups.values(): - tree.add(svg_items) - tree_x += svg_tree.image_size[0] - assert self.tree_plotbox == svg_tree.plotbox + for svg_items in subplot.root_groups.values(): + svg_subplot.add(svg_items) + subplot_x += subplot.image_size[0] + + def find_used_trees(self): + """ + Return a boolean array of which trees are actually plotted, + a list of which breakpoints are used to transition between plotted trees, + and a 2 x n array (often n=0) of indexes into these breakpoints delimiting + the regions that should be plotted as "skipped" + """ + tree_is_used = (self.tree_status & OMIT) != OMIT + break_used_as_tree_left = np.append(tree_is_used, False) + break_used_as_tree_right = np.insert(tree_is_used, 0, False) + break_used = np.logical_or(break_used_as_tree_left, break_used_as_tree_right) + all_breaks = self.ts.breakpoints(True) + used_breaks = all_breaks[break_used] + mark_skip_transitions = np.concatenate( + ([False], np.diff(self.tree_status & OMIT_MIDDLE) != 0, [False]) + ) + skipregion_indexes = np.where(mark_skip_transitions[break_used])[0] + assert len(skipregion_indexes) % 2 == 0 # all skipped regions have start, end + return tree_is_used, used_breaks, skipregion_indexes.reshape((-1, 2)) def draw_x_axis( self, x_scale, - tick_length_lower=SvgPlot.default_tick_length, - tick_length_upper=SvgPlot.default_tick_length_site, + tree_is_used, + breaks, + skipbreaks, + tick_length_lower=SvgAxisPlot.default_tick_length, + tick_length_upper=SvgAxisPlot.default_tick_length_site, ): """ - Add extra functionality to the original draw_x_axis method in SvgPlot, mainly + Add extra functionality to the original draw_x_axis method in SvgAxisPlot, to account for the background shading that is displayed in a tree sequence + and in case trees are omitted from the middle of the tree sequence """ if not self.x_axis and not self.x_label: return - left_break_status = np.append(self.tree_status, OMIT) - 
right_break_status = np.insert(self.tree_status, 0, OMIT) - use_left = (left_break_status & OMIT) != OMIT - use_right = (right_break_status & OMIT) != OMIT - all_breaks = self.ts.breakpoints(True) - breaks = all_breaks[np.logical_or(use_left, use_right)] if x_scale == "physical": - # Assume the trees are simply concatenated end-to-end - self.x_transform = ( - lambda x: self.plotbox.left - + (x - breaks[0]) / (breaks[-1] - breaks[0]) * self.plotbox.width + # In a tree sequence plot, the x_transform is used for the ticks, background + # shading positions, and sites along the x-axis. Each tree will have its own + # separate x_transform function for node positions within the tree. + + # For a plot with a break on the x-axis (representing "skipped" trees), the + # x_transform is a piecewise function. We need to identify the breakpoints + # where the x-scale transitions from the standard scale to the scale(s) used + # within a skipped region + + skipregion_plot_width = self.tree_plotbox.width + skipregion_span = np.diff(breaks[skipbreaks]).T[0] + std_scale = ( + self.plotbox.width - skipregion_plot_width * len(skipregion_span) + ) / (breaks[-1] - breaks[0] - np.sum(skipregion_span)) + skipregion_pos = breaks[skipbreaks].flatten() + genome_pos = np.concatenate(([breaks[0]], skipregion_pos, [breaks[-1]])) + plot_step = np.full(len(genome_pos) - 1, skipregion_plot_width) + plot_step[::2] = std_scale * np.diff(genome_pos)[::2] + plot_pos = np.cumsum(np.insert(plot_step, 0, self.plotbox.left)) + # Convert to slope + intercept form + slope = np.diff(plot_pos) / np.diff(genome_pos) + intercept = plot_pos[1:] - slope * genome_pos[1:] + self.x_transform = lambda y: ( + y * slope[np.searchsorted(skipregion_pos, y)] + + intercept[np.searchsorted(skipregion_pos, y)] ) + tick_positions = breaks + site_muts = { + s.id: s.mutations + for tree, use in zip(self.ts.trees(), tree_is_used) + for s in tree.sites() + if use + } + self.shade_background( breaks, tick_length_lower, 
self.tree_plotbox.max_x, self.plotbox.pad_bottom + self.tree_plotbox.pad_bottom, ) - site_muts = {s.id: s.mutations for s in self.ts.sites()} - # omit tick on LHS for trees that have been clipped on left, and same on RHS - use_left = np.logical_and( - use_left, (left_break_status & LEFT_CLIPPED_BIT) != LEFT_CLIPPED_BIT - ) - use_right = np.logical_and( - use_right, (right_break_status & RIGHT_CLIPPED_BIT) != RIGHT_CLIPPED_BIT - ) - super().draw_x_axis( - tick_positions=all_breaks[np.logical_or(use_left, use_right)], - tick_length_lower=tick_length_lower, - tick_length_upper=tick_length_upper, - site_muts=site_muts, - ) - else: - # No background shading needed if x_scale is "treewise" + + # For a treewise plot, the only time the x_transform is used is to apply + # to tick positions, so simply use positions 0..num_used_breaks for the + # positions, and a simple transform self.x_transform = ( lambda x: self.plotbox.left + x / (len(breaks) - 1) * self.plotbox.width ) - super().draw_x_axis( - tick_positions=np.arange(len(breaks)), - tick_labels=breaks, - tick_length_lower=tick_length_lower, - ) + tick_positions = np.arange(len(breaks)) + + site_muts = None # It doesn't make sense to plot sites for "treewise" plots + tick_length_upper = None # No sites plotted, so use the default upper tick + + # NB: no background shading needed if x_scale is "treewise" + + skipregion_pos = skipbreaks.flatten() + + first_tick = 1 if np.any(self.tree_status[tree_is_used] & LEFT_CLIP) else 0 + last_tick = -1 if np.any(self.tree_status[tree_is_used] & RIGHT_CLIP) else None + + super().draw_x_axis( + tick_positions=tick_positions[first_tick:last_tick], + tick_labels=breaks[first_tick:last_tick], + tick_length_lower=tick_length_lower, + tick_length_upper=tick_length_upper, + site_muts=site_muts, + alternate_dash_positions=skipregion_pos, + ) -class SvgTree(SvgPlot): +class SvgTree(SvgAxisPlot): """ A class to draw a tree in SVG format. 
@@ -1335,8 +1510,8 @@ def assign_y_coordinates( max_time, min_time, force_root_branch, - bottom_space=SvgPlot.line_height, - top_space=SvgPlot.line_height, + bottom_space=SvgAxisPlot.line_height, + top_space=SvgAxisPlot.line_height, ): """ Create a self.node_height dict, a self.timescaling instance and @@ -1346,8 +1521,8 @@ def assign_y_coordinates( """ max_time = check_max_time(max_time, self.time_scale != "rank") min_time = check_min_time(min_time, self.time_scale != "rank") - node_time = self.ts.tables.nodes.time - mut_time = self.ts.tables.mutations.time + node_time = self.ts.nodes_time + mut_time = self.ts.mutations_time root_branch_len = 0 if self.time_scale == "rank": t = np.zeros_like(node_time) @@ -1358,7 +1533,8 @@ def assign_y_coordinates( else: # only rank the nodes that are actually referenced in the edge table # (non-referenced nodes could occur if the user specifies x_lim values) - use_time = referenced_nodes(self.ts) + # However, we do include nodes in trees that have been skipped + use_time = edge_and_sample_nodes(self.ts) t[use_time] = node_time[use_time] node_time = t times = np.unique(node_time[node_time <= self.ts.max_root_time]) @@ -1400,7 +1576,7 @@ def assign_y_coordinates( min_time = min(self.node_height.values()) # don't need to check mutation times, as they must be above a node elif min_time == "ts": - min_time = np.min(self.ts.tables.nodes.time[referenced_nodes(self.ts)]) + min_time = np.min(self.ts.nodes_time[edge_and_sample_nodes(self.ts)]) # In pathological cases, all the nodes are at the same time if min_time == max_time: max_time = min_time + 1 @@ -1967,7 +2143,6 @@ def _draw(self): self.canvas[y, xv] = mid_char self.canvas[y, left] = left_child self.canvas[y, right] = right_child - # print(self.canvas) if self.orientation == TOP: self.canvas = np.flip(self.canvas, axis=0) # Reverse the time positions so that we can use them in the tree @@ -2070,4 +2245,3 @@ def _draw(self): # Move the padding to the left. 
self.canvas[:, :-1] = self.canvas[:, 1:] self.canvas[:, -1] = " " - # print(self.canvas) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 23cf681b4b..5311c4fcc3 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -6958,6 +6958,7 @@ def draw_svg( y_gridlines=None, omit_sites=None, canvas_size=None, + max_num_trees=None, **kwargs, ): """ @@ -7044,6 +7045,11 @@ def draw_svg( elements, allowing extra room e.g. for unusually long labels. If ``None`` take the canvas size to be the same as the target drawing size (see ``size``, above). Default: None + :param int max_num_trees: The maximum number of trees to plot. If there are + more trees than this in the tree sequence, the middle trees will be skipped + from the plot and a message "XX trees skipped" displayed in their place. + If ``None``, all the trees will be plotted: this can produce a very wide + plot if there are many trees in the tree sequence. Default: None :return: An SVG representation of a tree sequence. :rtype: SVGString @@ -7081,6 +7087,7 @@ def draw_svg( y_gridlines=y_gridlines, omit_sites=omit_sites, canvas_size=canvas_size, + max_num_trees=max_num_trees, **kwargs, ) output = draw.drawing.tostring() From 816e4aafdb9871af33771081bca7c2f1de68dab4 Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Mon, 28 Nov 2022 15:16:17 +0000 Subject: [PATCH 03/84] Update all the plots to the new scheme Note that some values are being listed as e.g. `-2.84217e-14` rather than `0`, so the `rnd` function could probably do with tweaking for the case of zero. 
--- python/CHANGELOG.rst | 4 ++++ python/tests/data/svg/internal_sample_ts.svg | 4 ++-- python/tests/data/svg/tree.svg | 2 +- python/tests/data/svg/tree_both_axes.svg | 6 +++--- python/tests/data/svg/tree_muts.svg | 2 +- python/tests/data/svg/tree_muts_all_edge.svg | 4 ++-- python/tests/data/svg/tree_timed_muts.svg | 2 +- python/tests/data/svg/tree_x_axis.svg | 4 ++-- python/tests/data/svg/tree_y_axis_rank.svg | 4 ++-- python/tests/data/svg/ts.svg | 4 ++-- python/tests/data/svg/ts_max_trees.svg | 6 +++--- python/tests/data/svg/ts_multiroot.svg | 10 +++++----- python/tests/data/svg/ts_mut_highlight.svg | 4 ++-- python/tests/data/svg/ts_mut_times.svg | 4 ++-- python/tests/data/svg/ts_mut_times_logscale.svg | 4 ++-- python/tests/data/svg/ts_mutations_no_edges.svg | 4 ++-- python/tests/data/svg/ts_mutations_timed_no_edges.svg | 4 ++-- python/tests/data/svg/ts_no_axes.svg | 2 +- python/tests/data/svg/ts_plain.svg | 4 ++-- python/tests/data/svg/ts_plain_no_xlab.svg | 4 ++-- python/tests/data/svg/ts_plain_y.svg | 6 +++--- python/tests/data/svg/ts_rank.svg | 6 +++--- python/tests/data/svg/ts_x_lim.svg | 4 ++-- python/tests/data/svg/ts_xlabel.svg | 4 ++-- python/tests/data/svg/ts_y_axis.svg | 6 +++--- python/tests/data/svg/ts_y_axis_log.svg | 6 +++--- python/tests/data/svg/ts_y_axis_regular.svg | 6 +++--- python/tests/test_drawing.py | 6 ++++++ python/tskit/drawing.py | 2 +- 29 files changed, 69 insertions(+), 59 deletions(-) diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 268680a378..0b3431c467 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -8,6 +8,10 @@ which are the minimum and maximum among the node times and mutation times, respectively. 
(:user:`szhan`, :pr:`2612`, :issue:`2271`) +- The ``draw_svg`` methods now have a ``max_num_trees`` parameter to truncate + the total number of trees shown, giving a readable display for tree + sequences with many trees (:user:`hyanwong`, :pr:`2652`) + - The ``draw_svg`` methods now accept a ``canvas_size`` parameter to allow extra room on the canvas e.g. for long labels or repositioned graphical elements (:user:`hyanwong`, :pr:`2646`, :issue:`2645`) diff --git a/python/tests/data/svg/internal_sample_ts.svg b/python/tests/data/svg/internal_sample_ts.svg index 8527c34b9d..42d55392a0 100644 --- a/python/tests/data/svg/internal_sample_ts.svg +++ b/python/tests/data/svg/internal_sample_ts.svg @@ -1,7 +1,7 @@ - + @@ -16,7 +16,7 @@ Genome position - + diff --git a/python/tests/data/svg/tree.svg b/python/tests/data/svg/tree.svg index 5fb67266a4..0ab913202c 100644 --- a/python/tests/data/svg/tree.svg +++ b/python/tests/data/svg/tree.svg @@ -1,7 +1,7 @@ - + diff --git a/python/tests/data/svg/tree_both_axes.svg b/python/tests/data/svg/tree_both_axes.svg index 1cc2c584d2..a86dd3d379 100644 --- a/python/tests/data/svg/tree_both_axes.svg +++ b/python/tests/data/svg/tree_both_axes.svg @@ -1,7 +1,7 @@ - + @@ -9,7 +9,7 @@ Genome position - + @@ -29,7 +29,7 @@ Time - + diff --git a/python/tests/data/svg/tree_muts.svg b/python/tests/data/svg/tree_muts.svg index 09abff1e02..3d2c017317 100644 --- a/python/tests/data/svg/tree_muts.svg +++ b/python/tests/data/svg/tree_muts.svg @@ -1,7 +1,7 @@ - + diff --git a/python/tests/data/svg/tree_muts_all_edge.svg b/python/tests/data/svg/tree_muts_all_edge.svg index 664e649ad6..adf4eb1a09 100644 --- a/python/tests/data/svg/tree_muts_all_edge.svg +++ b/python/tests/data/svg/tree_muts_all_edge.svg @@ -1,7 +1,7 @@ - + @@ -12,7 +12,7 @@ Genome position - + diff --git a/python/tests/data/svg/tree_timed_muts.svg b/python/tests/data/svg/tree_timed_muts.svg index 3efd7f32e2..0b79065c61 100644 --- a/python/tests/data/svg/tree_timed_muts.svg +++ 
b/python/tests/data/svg/tree_timed_muts.svg @@ -1,7 +1,7 @@ - + diff --git a/python/tests/data/svg/tree_x_axis.svg b/python/tests/data/svg/tree_x_axis.svg index be63748d24..e6c7af5cd6 100644 --- a/python/tests/data/svg/tree_x_axis.svg +++ b/python/tests/data/svg/tree_x_axis.svg @@ -1,7 +1,7 @@ - + @@ -9,7 +9,7 @@ pos on genome - + diff --git a/python/tests/data/svg/tree_y_axis_rank.svg b/python/tests/data/svg/tree_y_axis_rank.svg index 413b99c6db..9373285104 100644 --- a/python/tests/data/svg/tree_y_axis_rank.svg +++ b/python/tests/data/svg/tree_y_axis_rank.svg @@ -1,7 +1,7 @@ - + @@ -9,7 +9,7 @@ Time (relative steps) - + diff --git a/python/tests/data/svg/ts.svg b/python/tests/data/svg/ts.svg index 63d68cb5ee..d413b0a08b 100644 --- a/python/tests/data/svg/ts.svg +++ b/python/tests/data/svg/ts.svg @@ -1,7 +1,7 @@ - + @@ -16,7 +16,7 @@ Genome position - + diff --git a/python/tests/data/svg/ts_max_trees.svg b/python/tests/data/svg/ts_max_trees.svg index fff638c00e..3f3578de88 100644 --- a/python/tests/data/svg/ts_max_trees.svg +++ b/python/tests/data/svg/ts_max_trees.svg @@ -5,12 +5,12 @@ - + - - + + diff --git a/python/tests/data/svg/ts_multiroot.svg b/python/tests/data/svg/ts_multiroot.svg index ad61d5ec80..28dba3aa4e 100644 --- a/python/tests/data/svg/ts_multiroot.svg +++ b/python/tests/data/svg/ts_multiroot.svg @@ -1,12 +1,12 @@ - + - - + + @@ -19,7 +19,7 @@ Genome position - + @@ -141,7 +141,7 @@ Time (generations) - + diff --git a/python/tests/data/svg/ts_mut_highlight.svg b/python/tests/data/svg/ts_mut_highlight.svg index e8404ef4f7..0f7276d245 100644 --- a/python/tests/data/svg/ts_mut_highlight.svg +++ b/python/tests/data/svg/ts_mut_highlight.svg @@ -1,7 +1,7 @@ - + @@ -16,7 +16,7 @@ Genome position - + diff --git a/python/tests/data/svg/ts_mut_times.svg b/python/tests/data/svg/ts_mut_times.svg index 2ba161bb41..3bd6fb5ef3 100644 --- a/python/tests/data/svg/ts_mut_times.svg +++ b/python/tests/data/svg/ts_mut_times.svg @@ -1,7 +1,7 @@ - + @@ -16,7 +16,7 @@ 
Genome position - + diff --git a/python/tests/data/svg/ts_mut_times_logscale.svg b/python/tests/data/svg/ts_mut_times_logscale.svg index 669d4d97f6..86382d3cf8 100644 --- a/python/tests/data/svg/ts_mut_times_logscale.svg +++ b/python/tests/data/svg/ts_mut_times_logscale.svg @@ -1,7 +1,7 @@ - + @@ -16,7 +16,7 @@ Genome position - + diff --git a/python/tests/data/svg/ts_mutations_no_edges.svg b/python/tests/data/svg/ts_mutations_no_edges.svg index 547d0cc75f..4feb1e2a7a 100644 --- a/python/tests/data/svg/ts_mutations_no_edges.svg +++ b/python/tests/data/svg/ts_mutations_no_edges.svg @@ -1,7 +1,7 @@ - + @@ -12,7 +12,7 @@ Genome position - + diff --git a/python/tests/data/svg/ts_mutations_timed_no_edges.svg b/python/tests/data/svg/ts_mutations_timed_no_edges.svg index 37c5dc1fe6..de064ffc36 100644 --- a/python/tests/data/svg/ts_mutations_timed_no_edges.svg +++ b/python/tests/data/svg/ts_mutations_timed_no_edges.svg @@ -1,7 +1,7 @@ - + @@ -12,7 +12,7 @@ Genome position - + diff --git a/python/tests/data/svg/ts_no_axes.svg b/python/tests/data/svg/ts_no_axes.svg index 1e3c2ff479..051cbb1fb0 100644 --- a/python/tests/data/svg/ts_no_axes.svg +++ b/python/tests/data/svg/ts_no_axes.svg @@ -1,7 +1,7 @@ - + diff --git a/python/tests/data/svg/ts_plain.svg b/python/tests/data/svg/ts_plain.svg index f0586a9dae..6bb71f35a8 100644 --- a/python/tests/data/svg/ts_plain.svg +++ b/python/tests/data/svg/ts_plain.svg @@ -1,7 +1,7 @@ - + @@ -9,7 +9,7 @@ Genome position - + diff --git a/python/tests/data/svg/ts_plain_no_xlab.svg b/python/tests/data/svg/ts_plain_no_xlab.svg index fdf2bf4618..648a306187 100644 --- a/python/tests/data/svg/ts_plain_no_xlab.svg +++ b/python/tests/data/svg/ts_plain_no_xlab.svg @@ -1,12 +1,12 @@ - + - + diff --git a/python/tests/data/svg/ts_plain_y.svg b/python/tests/data/svg/ts_plain_y.svg index beac22d28b..9e4499d3f7 100644 --- a/python/tests/data/svg/ts_plain_y.svg +++ b/python/tests/data/svg/ts_plain_y.svg @@ -1,7 +1,7 @@ - + @@ -9,7 +9,7 @@ Genome position - 
+ @@ -53,7 +53,7 @@ Time - + diff --git a/python/tests/data/svg/ts_rank.svg b/python/tests/data/svg/ts_rank.svg index 3455d28316..b8527b6638 100644 --- a/python/tests/data/svg/ts_rank.svg +++ b/python/tests/data/svg/ts_rank.svg @@ -1,7 +1,7 @@ - + @@ -16,7 +16,7 @@ Genome position - + @@ -99,7 +99,7 @@ Node time - + diff --git a/python/tests/data/svg/ts_x_lim.svg b/python/tests/data/svg/ts_x_lim.svg index ce80434d50..e0aef7f41d 100644 --- a/python/tests/data/svg/ts_x_lim.svg +++ b/python/tests/data/svg/ts_x_lim.svg @@ -1,7 +1,7 @@ - + @@ -14,7 +14,7 @@ Genome position - + diff --git a/python/tests/data/svg/ts_xlabel.svg b/python/tests/data/svg/ts_xlabel.svg index da4a0b8b85..8af7c9dd36 100644 --- a/python/tests/data/svg/ts_xlabel.svg +++ b/python/tests/data/svg/ts_xlabel.svg @@ -1,7 +1,7 @@ - + @@ -16,7 +16,7 @@ genomic position (bp) - + diff --git a/python/tests/data/svg/ts_y_axis.svg b/python/tests/data/svg/ts_y_axis.svg index 202f5d4a65..dccc399050 100644 --- a/python/tests/data/svg/ts_y_axis.svg +++ b/python/tests/data/svg/ts_y_axis.svg @@ -1,7 +1,7 @@ - + @@ -16,7 +16,7 @@ Genome position - + @@ -99,7 +99,7 @@ Time (generations) - + diff --git a/python/tests/data/svg/ts_y_axis_log.svg b/python/tests/data/svg/ts_y_axis_log.svg index ac0051336f..70afefd41f 100644 --- a/python/tests/data/svg/ts_y_axis_log.svg +++ b/python/tests/data/svg/ts_y_axis_log.svg @@ -1,7 +1,7 @@ - + @@ -16,7 +16,7 @@ Genome position - + @@ -99,7 +99,7 @@ Time (log scale) - + diff --git a/python/tests/data/svg/ts_y_axis_regular.svg b/python/tests/data/svg/ts_y_axis_regular.svg index d2d866e51f..f5e4240cde 100644 --- a/python/tests/data/svg/ts_y_axis_regular.svg +++ b/python/tests/data/svg/ts_y_axis_regular.svg @@ -1,7 +1,7 @@ - + @@ -16,7 +16,7 @@ Genome position - + @@ -99,7 +99,7 @@ Time - + diff --git a/python/tests/test_drawing.py b/python/tests/test_drawing.py index 97dbe77bd5..e95ed10988 100644 --- a/python/tests/test_drawing.py +++ b/python/tests/test_drawing.py @@ -29,6 +29,7 @@ 
import math import os import pathlib +import platform import re import xml.dom.minidom import xml.etree @@ -44,6 +45,9 @@ from tskit import drawing +IS_WINDOWS = platform.system() == "Windows" + + class TestTreeDraw: """ Tests for the tree drawing functionality. @@ -2797,6 +2801,7 @@ def test_known_svg_ts_xlim(self, overwrite_viz, draw_plotbox, caplog): num_trees = sum(1 for b in ts.breakpoints() if 0.051 <= b < 0.9) + 1 self.verify_known_svg(svg, "ts_x_lim.svg", overwrite_viz, width=200 * num_trees) + @pytest.mark.skipif(IS_WINDOWS, reason="Msprime gives different result on Windows") def test_known_max_num_trees(self, overwrite_viz, draw_plotbox, caplog): max_trees = 5 ts = msprime.sim_ancestry( @@ -2816,6 +2821,7 @@ def test_known_max_num_trees(self, overwrite_viz, draw_plotbox, caplog): svg, "ts_max_trees.svg", overwrite_viz, width=200 * (max_trees + 1) ) + @pytest.mark.skipif(IS_WINDOWS, reason="Msprime gives different result on Windows") def test_known_max_num_trees_treewise(self, overwrite_viz, draw_plotbox, caplog): max_trees = 5 ts = msprime.sim_ancestry( diff --git a/python/tskit/drawing.py b/python/tskit/drawing.py index 3c6ad5beae..51e0260cd6 100644 --- a/python/tskit/drawing.py +++ b/python/tskit/drawing.py @@ -965,7 +965,7 @@ def shade_background( diag_h=rnd(diag_height), tick_h=rnd(tick_length_lower), ax_x=rnd(prev_break_x - break_x), - ldiag_x=rnd(prev_tree_x - prev_break_x), + ldiag_x=rnd(rnd(prev_tree_x) - rnd(prev_break_x)), ) ) ) From cba692251592364865392adfd27479060e57054f Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Fri, 6 Jan 2023 17:23:00 +0000 Subject: [PATCH 04/84] Avoid deprecation warnings Currently getting "FutureWarning: This property is a deprecated alias for Tree.tree_sequence.num_nodes and will be removed in the future" --- python/tests/__init__.py | 4 ++-- python/tests/test_genotype_matching_fb.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/tests/__init__.py b/python/tests/__init__.py index 
1f064b6a8d..f069f04f2e 100644 --- a/python/tests/__init__.py +++ b/python/tests/__init__.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -53,7 +53,7 @@ def __init__(self, num_nodes): @classmethod def from_tree(cls, tree): - ret = PythonTree(tree.num_nodes) + ret = PythonTree(tree.tree_sequence.num_nodes) ret.left, ret.right = tree.get_interval() ret.site_list = list(tree.sites()) ret.index = tree.get_index() diff --git a/python/tests/test_genotype_matching_fb.py b/python/tests/test_genotype_matching_fb.py index 88dd7a754d..984b3ce13a 100644 --- a/python/tests/test_genotype_matching_fb.py +++ b/python/tests/test_genotype_matching_fb.py @@ -174,8 +174,8 @@ def stupid_compress_dict(self): # Retain the old T_index, because the internal T that's passed up the tree will # retain this ordering. old_T_index = copy.deepcopy(self.T_index) - self.T_index = np.zeros(tree.num_nodes, dtype=int) - 1 - self.N = np.zeros(tree.num_nodes, dtype=int) + self.T_index = np.zeros(tree.tree_sequence.num_nodes, dtype=int) - 1 + self.N = np.zeros(tree.tree_sequence.num_nodes, dtype=int) self.T.clear() # First, create T root. 
@@ -345,7 +345,7 @@ def update_tree(self): vt.tree_node = -1 vt.value_index = -1 - self.N = np.zeros(self.tree.num_nodes, dtype=int) + self.N = np.zeros(self.tree.tree_sequence.num_nodes, dtype=int) node_map = {st.tree_node: st for st in self.T} for u in self.tree.samples(): From 4d8b341583badfcd0b4ecd9805b014b7803d0997 Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Fri, 6 Jan 2023 17:06:30 +0000 Subject: [PATCH 05/84] Ban negative root_thresholds --- python/tests/test_highlevel.py | 18 +++++++++++++++++- python/tskit/trees.py | 4 +++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index b124ea6fca..841ccc1e09 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # Copyright (c) 2015-2018 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -3867,6 +3867,22 @@ def test_branch_length_empty_tree(self): assert tree.branch_length(1) == 0 assert tree.total_branch_length == 0 + @pytest.mark.parametrize("r_threshold", [0, -1]) + def test_bad_val_root_threshold(self, r_threshold): + with pytest.raises(ValueError, match="greater than 0"): + tskit.Tree.generate_balanced(2, root_threshold=r_threshold) + + @pytest.mark.parametrize("r_threshold", [None, 0.5, 1.5, np.inf]) + def test_bad_type_root_threshold(self, r_threshold): + with pytest.raises(TypeError): + tskit.Tree.generate_balanced(2, root_threshold=r_threshold) + + def test_simple_root_threshold(self): + tree = tskit.Tree.generate_balanced(3, root_threshold=3) + assert tree.num_roots == 1 + tree = tskit.Tree.generate_balanced(3, root_threshold=4) + assert tree.num_roots == 0 + def test_is_descendant(self): def is_descendant(tree, u, v): path = [] diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 5311c4fcc3..49552545c0 100644 --- 
a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # Copyright (c) 2015-2018 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -665,6 +665,8 @@ def __init__( if sample_lists: options |= _tskit.SAMPLE_LISTS kwargs = {"options": options} + if root_threshold <= 0: + raise ValueError("Root threshold must be greater than 0") if tracked_samples is not None: # TODO remove this when we allow numpy arrays in the low-level API. kwargs["tracked_samples"] = list(tracked_samples) From d4efac9f1ab8a0871c873cdb12ab7473485769dc Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Tue, 10 Jan 2023 10:59:11 +0000 Subject: [PATCH 06/84] Make rust package manager use git cli for fetching --- .circleci/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 0b2e2056ca..7066012b29 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -12,7 +12,7 @@ commands: sudo pip install meson pip install numpy==1.18.5 pip install --user -r python/requirements/CI-complete/requirements.txt - pip install twine --user + CARGO_NET_GIT_FETCH_WITH_CLI=1 pip install twine --user # Remove tskit installed by msprime pip uninstall tskit -y echo 'export PATH=/home/circleci/.local/bin:$PATH' >> $BASH_ENV From 2cc38bf1792518530c80bf5286978634359395c3 Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Mon, 17 Oct 2022 20:17:55 +0100 Subject: [PATCH 07/84] Add TSK_SIMPLIFY_NO_FILTER_NODES --- c/tskit/tables.c | 63 +++-- c/tskit/tables.h | 5 + python/_tskitmodule.c | 12 +- python/tests/conftest.py | 8 + python/tests/simplify.py | 11 + python/tests/test_topology.py | 424 ++++++++++++++++++++++++---------- python/tests/tsutil.py | 39 ++-- python/tskit/tables.py | 15 +- python/tskit/trees.py | 32 ++- 9 files changed, 439 insertions(+), 170 deletions(-) diff --git 
a/c/tskit/tables.c b/c/tskit/tables.c index bb199bf200..43fd828d2a 100644 --- a/c/tskit/tables.c +++ b/c/tskit/tables.c @@ -8697,6 +8697,8 @@ simplifier_print_state(simplifier_t *self, FILE *out) fprintf(out, "options:\n"); fprintf(out, "\tfilter_unreferenced_sites : %d\n", !!(self->options & TSK_SIMPLIFY_FILTER_SITES)); + fprintf(out, "\tno_filter_nodes : %d\n", + !!(self->options & TSK_SIMPLIFY_NO_FILTER_NODES)); fprintf(out, "\treduce_to_site_topology : %d\n", !!(self->options & TSK_SIMPLIFY_REDUCE_TO_SITE_TOPOLOGY)); fprintf(out, "\tkeep_unary : %d\n", @@ -9059,16 +9061,17 @@ simplifier_add_ancestry( } static int -simplifier_init_samples(simplifier_t *self, const tsk_id_t *samples) +simplifier_init_nodes(simplifier_t *self, const tsk_id_t *samples) { int ret = 0; tsk_id_t node_id; tsk_size_t j; + tsk_size_t num_nodes = self->input_tables.nodes.num_rows; + bool filter_nodes = !(self->options & TSK_SIMPLIFY_NO_FILTER_NODES); + bool is_sample; - /* Go through the samples to check for errors. */ for (j = 0; j < self->num_samples; j++) { - if (samples[j] < 0 - || samples[j] > (tsk_id_t) self->input_tables.nodes.num_rows) { + if (samples[j] < 0 || samples[j] > (tsk_id_t) num_nodes) { ret = TSK_ERR_NODE_OUT_OF_BOUNDS; goto out; } @@ -9077,15 +9080,38 @@ simplifier_init_samples(simplifier_t *self, const tsk_id_t *samples) goto out; } self->is_sample[samples[j]] = true; - node_id = simplifier_record_node(self, samples[j], true); - if (node_id < 0) { - ret = (int) node_id; - goto out; + } + + if (filter_nodes) { + /* Go through the samples to check for errors. 
*/ + for (j = 0; j < self->num_samples; j++) { + node_id = simplifier_record_node(self, samples[j], true); + if (node_id < 0) { + ret = (int) node_id; + goto out; + } + ret = simplifier_add_ancestry( + self, samples[j], 0, self->tables->sequence_length, node_id); + if (ret != 0) { + goto out; + } } - ret = simplifier_add_ancestry( - self, samples[j], 0, self->tables->sequence_length, node_id); - if (ret != 0) { - goto out; + } else { + /* record all the nodes, but only save ancestry for those in the sample */ + for (j = 0; j < num_nodes; j++) { + is_sample = self->is_sample[j]; + node_id = simplifier_record_node(self, (tsk_id_t) j, is_sample); + if (node_id < 0) { + ret = (int) node_id; + goto out; + } + if (is_sample) { + ret = simplifier_add_ancestry( + self, node_id, 0, self->tables->sequence_length, node_id); + if (ret != 0) { + goto out; + } + } } } out: @@ -9174,7 +9200,7 @@ simplifier_init(simplifier_t *self, const tsk_id_t *samples, tsk_size_t num_samp if (ret != 0) { goto out; } - ret = simplifier_init_samples(self, samples); + ret = simplifier_init_nodes(self, samples); if (ret != 0) { goto out; } @@ -9253,11 +9279,10 @@ simplifier_merge_ancestors(simplifier_t *self, tsk_id_t input_id) tsk_id_t ancestry_node; tsk_id_t output_id = self->node_id_map[input_id]; - bool is_sample = output_id != TSK_NULL; - bool keep_unary = false; - if (self->options & TSK_SIMPLIFY_KEEP_UNARY) { - keep_unary = true; - } + bool is_sample = self->is_sample[input_id]; + /* bool is_sample = output_id != TSK_NULL; */ + bool filter_nodes = !(self->options & TSK_SIMPLIFY_NO_FILTER_NODES); + bool keep_unary = !!(self->options & TSK_SIMPLIFY_KEEP_UNARY); if ((self->options & TSK_SIMPLIFY_KEEP_UNARY_IN_INDIVIDUALS) && (self->input_tables.nodes.individual[input_id] != TSK_NULL)) { keep_unary = true; @@ -9348,7 +9373,7 @@ simplifier_merge_ancestors(simplifier_t *self, tsk_id_t input_id) if (ret != 0) { goto out; } - if (num_flushed_edges == 0 && !is_sample) { + if (filter_nodes && 
(num_flushed_edges == 0) && !is_sample) { ret = simplifier_rewind_node(self, input_id, output_id); } } diff --git a/c/tskit/tables.h b/c/tskit/tables.h index bab3546220..8fc0bc4cfb 100644 --- a/c/tskit/tables.h +++ b/c/tskit/tables.h @@ -715,6 +715,10 @@ flag). It keeps unary nodes, but only if the unary node is referenced from an in @endrst */ #define TSK_SIMPLIFY_KEEP_UNARY_IN_INDIVIDUALS (1 << 6) +/** Retain nodes in the output even if no edges reference them. This is negated +compared to the other TSK_SIMPLIFY_FILTER_XXX flags to preserve previous behaviour. +*/ +#define TSK_SIMPLIFY_NO_FILTER_NODES (1 << 7) /** @} */ /** @@ -3928,6 +3932,7 @@ Options can be specified by providing one or more of the following bitwise - :c:macro:`TSK_SIMPLIFY_KEEP_UNARY` - :c:macro:`TSK_SIMPLIFY_KEEP_INPUT_ROOTS` - :c:macro:`TSK_SIMPLIFY_KEEP_UNARY_IN_INDIVIDUALS` +- :c:macro:`TSK_SIMPLIFY_NO_FILTER_NODES` @endrst @param self A pointer to a tsk_table_collection_t object. diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 22f78c1244..6d5248b209 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -6585,22 +6585,23 @@ TableCollection_simplify(TableCollection *self, PyObject *args, PyObject *kwds) npy_intp *shape, dims; tsk_size_t num_samples; tsk_flags_t options = 0; - int filter_sites = true; + int filter_sites = false; int filter_individuals = false; int filter_populations = false; + int filter_nodes = true; int keep_unary = false; int keep_unary_in_individuals = false; int keep_input_roots = false; int reduce_to_site_topology = false; static char *kwlist[] = { "samples", "filter_sites", "filter_populations", - "filter_individuals", "reduce_to_site_topology", "keep_unary", + "filter_individuals", "filter_nodes", "reduce_to_site_topology", "keep_unary", "keep_unary_in_individuals", "keep_input_roots", NULL }; if (TableCollection_check_state(self) != 0) { goto out; } - if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|iiiiiii", kwlist, &samples, - 
&filter_sites, &filter_populations, &filter_individuals, + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|iiiiiiii", kwlist, &samples, + &filter_sites, &filter_populations, &filter_individuals, &filter_nodes, + &reduce_to_site_topology, &keep_unary, &keep_unary_in_individuals, &keep_input_roots)) { goto out; @@ -6621,6 +6622,9 @@ TableCollection_simplify(TableCollection *self, PyObject *args, PyObject *kwds) if (filter_populations) { options |= TSK_SIMPLIFY_FILTER_POPULATIONS; } + if (!filter_nodes) { + options |= TSK_SIMPLIFY_NO_FILTER_NODES; + } if (reduce_to_site_topology) { options |= TSK_SIMPLIFY_REDUCE_TO_SITE_TOPOLOGY; } diff --git a/python/tests/conftest.py b/python/tests/conftest.py index d2539ed0fb..d23c019003 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -111,6 +111,14 @@ def ts_fixture(): return tsutil.all_fields_ts() +@fixture(scope="session") +def ts_fixture_for_simplify(): + """ + A tree sequence with data in all fields except edge metadata and migrations + """ + return tsutil.all_fields_ts(edge_metadata=False, migrations=False) + + @fixture(scope="session") def replicate_ts_fixture(): """ diff --git a/python/tests/simplify.py b/python/tests/simplify.py index 5f7e838e6c..e27604bf2c 100644 --- a/python/tests/simplify.py +++ b/python/tests/simplify.py @@ -111,6 +111,7 @@ def __init__( keep_unary=False, keep_unary_in_individuals=False, keep_input_roots=False, + filter_nodes=True, # If this is False, the order in `sample` is ignored ): self.ts = ts self.n = len(sample) @@ -119,6 +120,7 @@ def __init__( self.filter_sites = filter_sites self.filter_populations = filter_populations self.filter_individuals = filter_individuals + self.filter_nodes = filter_nodes self.keep_unary = keep_unary self.keep_unary_in_individuals = keep_unary_in_individuals self.keep_input_roots = keep_input_roots @@ -128,6 +130,11 @@ def __init__( self.A_tail = [None for _ in range(ts.num_nodes)] self.tables = self.ts.tables.copy() self.tables.clear() + if 
not filter_nodes: + # NOTE: this is hack-ish. + # So far, we have copied the tables once, + # cleared them, and then re-copied. + self.tables = self.ts.tables.copy() self.edge_buffer = {} self.node_id_map = np.zeros(ts.num_nodes, dtype=np.int32) - 1 self.mutation_node_map = [-1 for _ in range(self.num_mutations)] @@ -144,6 +151,7 @@ def __init__( for sample_id in sample: output_id = self.record_node(sample_id, is_sample=True) self.add_ancestry(sample_id, 0, self.sequence_length, output_id) + self.position_lookup = None if self.reduce_to_site_topology: self.position_lookup = np.hstack([[0], position, [self.sequence_length]]) @@ -159,6 +167,9 @@ def record_node(self, input_id, is_sample=False): flags &= ~tskit.NODE_IS_SAMPLE if is_sample: flags |= tskit.NODE_IS_SAMPLE + if not self.filter_nodes: + self.node_id_map[input_id] = input_id + return input_id output_id = self.tables.nodes.append(node.replace(flags=flags)) self.node_id_map[input_id] = output_id return output_id diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index bb2354dd63..ac1368e921 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -2641,8 +2641,10 @@ class TestSimplifyExamples(TopologyTestCase): def verify_simplify( self, samples, + *, filter_sites=True, keep_input_roots=False, + filter_nodes=True, nodes_before=None, edges_before=None, sites_before=None, @@ -2657,7 +2659,7 @@ def verify_simplify( Verifies that if we run simplify on the specified input we get the required output. 
""" - ts = tskit.load_text( + before = tskit.load_text( nodes=io.StringIO(nodes_before), edges=io.StringIO(edges_before), sites=io.StringIO(sites_before) if sites_before is not None else None, @@ -2666,9 +2668,8 @@ def verify_simplify( ), strict=False, ) - before = ts.dump_tables() - ts = tskit.load_text( + after = tskit.load_text( nodes=io.StringIO(nodes_after), edges=io.StringIO(edges_after), sites=io.StringIO(sites_after) if sites_after is not None else None, @@ -2678,23 +2679,26 @@ def verify_simplify( strict=False, sequence_length=before.sequence_length, ) - after = ts.dump_tables() - # Make sure it's a valid tree sequence - ts = before.tree_sequence() - before.simplify( + + result, _ = do_simplify( + before, samples=samples, filter_sites=filter_sites, keep_input_roots=keep_input_roots, - record_provenance=False, + filter_nodes=filter_nodes, + compare_lib=False, # TMP ) if debug: print("before") print(before) - print(before.tree_sequence().draw_text()) + print(before.draw_text()) print("after") print(after) - print(after.tree_sequence().draw_text()) - assert before == after + print(after.draw_text()) + print("result") + print(result) + print(result.draw_text()) + after.tables.assert_equals(result.tables) def test_unsorted_edges(self): # We have two nodes at the same time and interleave edges for @@ -3250,6 +3254,38 @@ def test_unary_edges_no_overlap_internal_sample(self): edges_after=edges_before, ) + def test_keep_nodes(self): + nodes_before = """\ + id is_sample time + 0 1 0 + 1 1 0 + 2 0 1 + 3 0 2 + 4 0 3 + """ + edges_before = """\ + left right parent child + 0 1 2 0 + 0 1 2 1 + 0 1 3 2 + 0 1 4 3 + """ + edges_after = """\ + left right parent child + 0 1 2 0 + 0 1 2 1 + 0 1 4 2 + """ + self.verify_simplify( + samples=[0, 1], + nodes_before=nodes_before, + edges_before=edges_before, + nodes_after=nodes_before, + edges_after=edges_after, + filter_nodes=False, + keep_input_roots=True, + ) + class TestNonSampleExternalNodes(TopologyTestCase): """ @@ -4711,74 
+4747,77 @@ def test_kwargs(self): assert t1.num_tracked_samples() == t2.num_tracked_samples() == 4 -class SimplifyTestBase: +def do_simplify( + ts, + samples=None, + compare_lib=True, + filter_sites=True, + filter_populations=True, + filter_individuals=True, + filter_nodes=True, + keep_unary=False, + keep_input_roots=False, +): """ - Base class for simplify tests. + Runs the Python test implementation of simplify. """ - - def do_simplify( - self, + if samples is None: + samples = ts.samples() + s = tests.Simplifier( ts, - samples=None, - compare_lib=True, - filter_sites=True, - filter_populations=True, - filter_individuals=True, - keep_unary=False, - keep_input_roots=False, - ): - """ - Runs the Python test implementation of simplify. - """ - if samples is None: - samples = ts.samples() - s = tests.Simplifier( - ts, + samples, + filter_sites=filter_sites, + filter_populations=filter_populations, + filter_individuals=filter_individuals, + filter_nodes=filter_nodes, + keep_unary=keep_unary, + keep_input_roots=keep_input_roots, + ) + new_ts, node_map = s.simplify() + if compare_lib: + sts, lib_node_map1 = ts.simplify( samples, filter_sites=filter_sites, - filter_populations=filter_populations, filter_individuals=filter_individuals, + filter_populations=filter_populations, + filter_nodes=filter_nodes, keep_unary=keep_unary, keep_input_roots=keep_input_roots, + map_nodes=True, ) - new_ts, node_map = s.simplify() - if compare_lib: - sts, lib_node_map1 = ts.simplify( - samples, - filter_sites=filter_sites, - filter_individuals=filter_individuals, - filter_populations=filter_populations, - keep_unary=keep_unary, - keep_input_roots=keep_input_roots, - map_nodes=True, - ) - lib_tables1 = sts.dump_tables() - - lib_tables2 = ts.dump_tables() - lib_node_map2 = lib_tables2.simplify( - samples, - filter_sites=filter_sites, - keep_unary=keep_unary, - keep_input_roots=keep_input_roots, - filter_individuals=filter_individuals, - filter_populations=filter_populations, - ) + 
lib_tables1 = sts.dump_tables() - py_tables = new_ts.dump_tables() - for lib_tables, lib_node_map in [ - (lib_tables1, lib_node_map1), - (lib_tables2, lib_node_map2), - ]: + lib_tables2 = ts.dump_tables() + lib_node_map2 = lib_tables2.simplify( + samples, + filter_sites=filter_sites, + keep_unary=keep_unary, + keep_input_roots=keep_input_roots, + filter_individuals=filter_individuals, + filter_populations=filter_populations, + filter_nodes=filter_nodes, + ) + + py_tables = new_ts.dump_tables() + for lib_tables, lib_node_map in [ + (lib_tables1, lib_node_map1), + (lib_tables2, lib_node_map2), + ]: + assert lib_tables.nodes == py_tables.nodes + assert lib_tables.edges == py_tables.edges + assert lib_tables.migrations == py_tables.migrations + assert lib_tables.sites == py_tables.sites + assert lib_tables.mutations == py_tables.mutations + assert lib_tables.individuals == py_tables.individuals + assert lib_tables.populations == py_tables.populations + assert all(node_map == lib_node_map) + return new_ts, node_map - assert lib_tables.nodes == py_tables.nodes - assert lib_tables.edges == py_tables.edges - assert lib_tables.migrations == py_tables.migrations - assert lib_tables.sites == py_tables.sites - assert lib_tables.mutations == py_tables.mutations - assert lib_tables.individuals == py_tables.individuals - assert lib_tables.populations == py_tables.populations - assert all(node_map == lib_node_map) - return new_ts, node_map + +class SimplifyTestBase: + """ + Base class for simplify tests. 
+ """ class TestSimplify(SimplifyTestBase): @@ -4824,11 +4863,9 @@ def verify_no_samples(self, ts, keep_unary=False): """ t1 = ts.dump_tables() t1.nodes.flags = np.zeros_like(t1.nodes.flags) - ts1, node_map1 = self.do_simplify( - ts, samples=ts.samples(), keep_unary=keep_unary - ) + ts1, node_map1 = do_simplify(ts, samples=ts.samples(), keep_unary=keep_unary) t1 = ts1.dump_tables() - ts2, node_map2 = self.do_simplify(ts, keep_unary=keep_unary) + ts2, node_map2 = do_simplify(ts, keep_unary=keep_unary) t2 = ts2.dump_tables() t1.assert_equals(t2) @@ -4841,7 +4878,7 @@ def verify_single_childified(self, ts, keep_unary=False): """ ts_single = tsutil.single_childify(ts) - tss, node_map = self.do_simplify(ts_single, keep_unary=keep_unary) + tss, node_map = do_simplify(ts_single, keep_unary=keep_unary) # All original nodes should still be present. for u in range(ts.num_samples): assert u == node_map[u] @@ -4865,7 +4902,7 @@ def verify_single_childified(self, ts, keep_unary=False): def verify_multiroot_internal_samples(self, ts, keep_unary=False): ts_multiroot = ts.decapitate(np.max(ts.tables.nodes.time) / 2) ts1 = tsutil.jiggle_samples(ts_multiroot) - ts2, node_map = self.do_simplify(ts1, keep_unary=keep_unary) + ts2, node_map = do_simplify(ts1, keep_unary=keep_unary) assert ts1.num_trees >= ts2.num_trees trees2 = ts2.trees() t2 = next(trees2) @@ -4897,10 +4934,10 @@ def test_single_tree(self): def test_single_tree_mutations(self): ts = msprime.simulate(10, mutation_rate=1, random_seed=self.random_seed) assert ts.num_sites > 1 - self.do_simplify(ts) + do_simplify(ts) self.verify_single_childified(ts) # Also with keep_unary == True. 
- self.do_simplify(ts, keep_unary=True) + do_simplify(ts, keep_unary=True) self.verify_single_childified(ts, keep_unary=True) def test_many_trees_mutations(self): @@ -4910,10 +4947,10 @@ def test_many_trees_mutations(self): assert ts.num_trees > 2 assert ts.num_sites > 2 self.verify_no_samples(ts) - self.do_simplify(ts) + do_simplify(ts) self.verify_single_childified(ts) # Also with keep_unary == True. - self.do_simplify(ts, keep_unary=True) + do_simplify(ts, keep_unary=True) self.verify_single_childified(ts, keep_unary=True) def test_many_trees(self): @@ -4944,14 +4981,14 @@ def test_small_tree_internal_samples(self): nodes.flags = flags ts = tables.tree_sequence() assert ts.sample_size == 5 - tss, node_map = self.do_simplify(ts, [3, 5]) + tss, node_map = do_simplify(ts, [3, 5]) assert node_map[3] == 0 assert node_map[5] == 1 assert tss.num_nodes == 3 assert tss.num_edges == 2 self.verify_no_samples(ts) # with keep_unary == True - tss, node_map = self.do_simplify(ts, [3, 5], keep_unary=True) + tss, node_map = do_simplify(ts, [3, 5], keep_unary=True) assert node_map[3] == 0 assert node_map[5] == 1 assert tss.num_nodes == 5 @@ -4974,7 +5011,7 @@ def test_small_tree_linear_samples(self): nodes.flags = flags ts = tables.tree_sequence() assert ts.sample_size == 2 - tss, node_map = self.do_simplify(ts, [0, 7]) + tss, node_map = do_simplify(ts, [0, 7]) assert node_map[0] == 0 assert node_map[7] == 1 assert tss.num_nodes == 2 @@ -4982,7 +5019,7 @@ def test_small_tree_linear_samples(self): t = next(tss.trees()) assert t.parent_dict == {0: 1} # with keep_unary == True - tss, node_map = self.do_simplify(ts, [0, 7], keep_unary=True) + tss, node_map = do_simplify(ts, [0, 7], keep_unary=True) assert node_map[0] == 0 assert node_map[7] == 1 assert tss.num_nodes == 4 @@ -5006,7 +5043,7 @@ def test_small_tree_internal_and_external_samples(self): nodes.flags = flags ts = tables.tree_sequence() assert ts.sample_size == 3 - tss, node_map = self.do_simplify(ts, [0, 1, 7]) + tss, 
node_map = do_simplify(ts, [0, 1, 7]) assert node_map[0] == 0 assert node_map[1] == 1 assert node_map[7] == 2 @@ -5015,7 +5052,7 @@ def test_small_tree_internal_and_external_samples(self): t = next(tss.trees()) assert t.parent_dict == {0: 3, 1: 3, 3: 2} # with keep_unary == True - tss, node_map = self.do_simplify(ts, [0, 1, 7], keep_unary=True) + tss, node_map = do_simplify(ts, [0, 1, 7], keep_unary=True) assert node_map[0] == 0 assert node_map[1] == 1 assert node_map[7] == 2 @@ -5044,7 +5081,7 @@ def test_small_tree_mutations(self): assert ts.num_sites == 4 assert ts.num_mutations == 4 for keep in [True, False]: - tss = self.do_simplify(ts, [0, 2], keep_unary=keep)[0] + tss = do_simplify(ts, [0, 2], keep_unary=keep)[0] assert tss.sample_size == 2 assert tss.num_mutations == 4 assert list(tss.haplotypes()) == ["1011", "0100"] @@ -5059,12 +5096,10 @@ def test_small_tree_filter_zero_mutations(self): assert ts.num_sites == 8 assert ts.num_mutations == 8 for keep in [True, False]: - tss, _ = self.do_simplify(ts, [4, 0, 1], filter_sites=True, keep_unary=keep) + tss, _ = do_simplify(ts, [4, 0, 1], filter_sites=True, keep_unary=keep) assert tss.num_sites == 5 assert tss.num_mutations == 5 - tss, _ = self.do_simplify( - ts, [4, 0, 1], filter_sites=False, keep_unary=keep - ) + tss, _ = do_simplify(ts, [4, 0, 1], filter_sites=False, keep_unary=keep) assert tss.num_sites == 8 assert tss.num_mutations == 5 @@ -5086,7 +5121,7 @@ def test_small_tree_fixed_sites(self): assert ts.num_sites == 3 assert ts.num_mutations == 3 for keep in [True, False]: - tss, _ = self.do_simplify(ts, [4, 1], keep_unary=keep) + tss, _ = do_simplify(ts, [4, 1], keep_unary=keep) assert tss.sample_size == 2 assert tss.num_mutations == 0 assert list(tss.haplotypes()) == ["", ""] @@ -5104,7 +5139,7 @@ def test_small_tree_mutations_over_root(self): assert ts.num_sites == 1 assert ts.num_mutations == 1 for keep_unary, filter_sites in itertools.product([True, False], repeat=2): - tss, _ = self.do_simplify( + 
tss, _ = do_simplify( ts, [0, 1], filter_sites=filter_sites, keep_unary=keep_unary ) assert tss.num_sites == 1 @@ -5125,7 +5160,7 @@ def test_small_tree_recurrent_mutations(self): assert ts.num_sites == 1 assert ts.num_mutations == 2 for keep in [True, False]: - tss = self.do_simplify(ts, [4, 3], keep_unary=keep)[0] + tss = do_simplify(ts, [4, 3], keep_unary=keep)[0] assert tss.sample_size == 2 assert tss.num_sites == 1 assert tss.num_mutations == 2 @@ -5149,7 +5184,7 @@ def test_small_tree_back_mutations(self): assert list(ts.haplotypes()) == ["0", "1", "0", "0", "1"] # First check if we simplify for all samples and keep original state. for keep in [True, False]: - tss = self.do_simplify(ts, [0, 1, 2, 3, 4], keep_unary=keep)[0] + tss = do_simplify(ts, [0, 1, 2, 3, 4], keep_unary=keep)[0] assert tss.sample_size == 5 assert tss.num_sites == 1 assert tss.num_mutations == 3 @@ -5157,7 +5192,7 @@ def test_small_tree_back_mutations(self): # The ancestral state above 5 should be 0. for keep in [True, False]: - tss = self.do_simplify(ts, [0, 1], keep_unary=keep)[0] + tss = do_simplify(ts, [0, 1], keep_unary=keep)[0] assert tss.sample_size == 2 assert tss.num_sites == 1 assert tss.num_mutations == 3 @@ -5165,7 +5200,7 @@ def test_small_tree_back_mutations(self): # The ancestral state above 7 should be 1. 
for keep in [True, False]: - tss = self.do_simplify(ts, [4, 0, 1], keep_unary=keep)[0] + tss = do_simplify(ts, [4, 0, 1], keep_unary=keep)[0] assert tss.sample_size == 3 assert tss.num_sites == 1 assert tss.num_mutations == 3 @@ -5192,7 +5227,7 @@ def test_overlapping_unary_edges(self): assert ts.num_trees == 3 assert ts.sequence_length == 3 for keep in [True, False]: - tss, node_map = self.do_simplify(ts, samples=[0, 1, 2], keep_unary=keep) + tss, node_map = do_simplify(ts, samples=[0, 1, 2], keep_unary=keep) assert list(node_map) == [0, 1, 2] trees = [{0: 2}, {0: 2, 1: 2}, {1: 2}] for t in tss.trees(): @@ -5220,7 +5255,7 @@ def test_overlapping_unary_edges_internal_samples(self): trees = [{0: 2}, {0: 2, 1: 2}, {1: 2}] for t in ts.trees(): assert t.parent_dict == trees[t.index] - tss, node_map = self.do_simplify(ts) + tss, node_map = do_simplify(ts) assert list(node_map) == [0, 1, 2] def test_isolated_samples(self): @@ -5242,7 +5277,7 @@ def test_isolated_samples(self): assert ts.num_trees == 1 assert ts.num_nodes == 3 for keep in [True, False]: - tss, node_map = self.do_simplify(ts, keep_unary=keep) + tss, node_map = do_simplify(ts, keep_unary=keep) assert ts.tables.nodes == tss.tables.nodes assert ts.tables.edges == tss.tables.edges assert list(node_map) == [0, 1, 2] @@ -5275,7 +5310,7 @@ def test_internal_samples(self): ) ts = tskit.load_text(nodes, edges, strict=False) - tss, node_map = self.do_simplify(ts, [5, 2, 0]) + tss, node_map = do_simplify(ts, [5, 2, 0]) assert node_map[0] == 2 assert node_map[1] == -1 assert node_map[2] == 1 @@ -5290,7 +5325,7 @@ def test_internal_samples(self): for t in tss.trees(): assert t.parent_dict == trees[t.index] # with keep_unary == True - tss, node_map = self.do_simplify(ts, [5, 2, 0], keep_unary=True) + tss, node_map = do_simplify(ts, [5, 2, 0], keep_unary=True) assert node_map[0] == 2 assert node_map[1] == 4 assert node_map[2] == 1 @@ -5343,7 +5378,7 @@ def test_many_mutations_over_single_sample_ancestral_state(self): 
assert ts.num_sites == 1 assert ts.num_mutations == 2 for keep in [True, False]: - tss, node_map = self.do_simplify(ts, keep_unary=keep) + tss, node_map = do_simplify(ts, keep_unary=keep) assert tss.num_sites == 1 assert tss.num_mutations == 2 assert list(tss.haplotypes(isolated_as_missing=False)) == ["0"] @@ -5384,7 +5419,7 @@ def test_many_mutations_over_single_sample_derived_state(self): assert ts.num_sites == 1 assert ts.num_mutations == 3 for keep in [True, False]: - tss, node_map = self.do_simplify(ts, keep_unary=keep) + tss, node_map = do_simplify(ts, keep_unary=keep) assert tss.num_sites == 1 assert tss.num_mutations == 3 assert list(tss.haplotypes(isolated_as_missing=False)) == ["1"] @@ -5397,7 +5432,7 @@ def test_many_trees_filter_zero_mutations(self): assert ts.num_sites > ts.num_trees for keep in [True, False]: for filter_sites in [True, False]: - tss, _ = self.do_simplify( + tss, _ = do_simplify( ts, samples=None, filter_sites=filter_sites, keep_unary=keep ) assert ts.num_sites == tss.num_sites @@ -5411,7 +5446,7 @@ def test_many_trees_filter_zero_multichar_mutations(self): assert ts.num_mutations == ts.num_trees for keep in [True, False]: for filter_sites in [True, False]: - tss, _ = self.do_simplify( + tss, _ = do_simplify( ts, samples=None, filter_sites=filter_sites, keep_unary=keep ) assert ts.num_sites == tss.num_sites @@ -5423,11 +5458,11 @@ def test_simple_population_filter(self): tables.populations.add_row(metadata=b"unreferenced") assert len(tables.populations) == 2 for keep in [True, False]: - tss, _ = self.do_simplify( + tss, _ = do_simplify( tables.tree_sequence(), filter_populations=True, keep_unary=keep ) assert tss.num_populations == 1 - tss, _ = self.do_simplify( + tss, _ = do_simplify( tables.tree_sequence(), filter_populations=False, keep_unary=keep ) assert tss.num_populations == 2 @@ -5451,14 +5486,14 @@ def test_interleaved_populations_filter(self): ts = tables.tree_sequence() id_map = np.array([-1, 0, -1, -1], dtype=np.int32) for 
keep in [True, False]: - tss, _ = self.do_simplify(ts, filter_populations=True, keep_unary=keep) + tss, _ = do_simplify(ts, filter_populations=True, keep_unary=keep) assert tss.num_populations == 1 population = tss.population(0) assert population.metadata == bytes([1]) assert np.array_equal( id_map[ts.tables.nodes.population], tss.tables.nodes.population ) - tss, _ = self.do_simplify(ts, filter_populations=False, keep_unary=keep) + tss, _ = do_simplify(ts, filter_populations=False, keep_unary=keep) assert tss.num_populations == 4 def test_removed_node_population_filter(self): @@ -5472,7 +5507,7 @@ def test_removed_node_population_filter(self): tables.nodes.add_row(flags=0, population=1) tables.nodes.add_row(flags=1, population=2) for keep in [True, False]: - tss, _ = self.do_simplify( + tss, _ = do_simplify( tables.tree_sequence(), filter_populations=True, keep_unary=keep ) assert tss.num_nodes == 2 @@ -5482,7 +5517,7 @@ def test_removed_node_population_filter(self): assert tss.node(0).population == 0 assert tss.node(1).population == 1 - tss, _ = self.do_simplify( + tss, _ = do_simplify( tables.tree_sequence(), filter_populations=False, keep_unary=keep ) assert tss.tables.populations == tables.populations @@ -5494,14 +5529,14 @@ def test_simple_individual_filter(self): tables.nodes.add_row(flags=1, individual=0) tables.nodes.add_row(flags=1, individual=0) for keep in [True, False]: - tss, _ = self.do_simplify( + tss, _ = do_simplify( tables.tree_sequence(), filter_individuals=True, keep_unary=keep ) assert tss.num_nodes == 2 assert tss.num_individuals == 1 assert tss.individual(0).flags == 0 - tss, _ = self.do_simplify(tables.tree_sequence(), filter_individuals=False) + tss, _ = do_simplify(tables.tree_sequence(), filter_individuals=False) assert tss.tables.individuals == tables.individuals def test_interleaved_individual_filter(self): @@ -5513,14 +5548,14 @@ def test_interleaved_individual_filter(self): tables.nodes.add_row(flags=1, individual=-1) 
tables.nodes.add_row(flags=1, individual=1) for keep in [True, False]: - tss, _ = self.do_simplify( + tss, _ = do_simplify( tables.tree_sequence(), filter_individuals=True, keep_unary=keep ) assert tss.num_nodes == 3 assert tss.num_individuals == 1 assert tss.individual(0).flags == 1 - tss, _ = self.do_simplify( + tss, _ = do_simplify( tables.tree_sequence(), filter_individuals=False, keep_unary=keep ) assert tss.tables.individuals == tables.individuals @@ -5536,7 +5571,7 @@ def test_removed_node_individual_filter(self): tables.nodes.add_row(flags=0, individual=1) tables.nodes.add_row(flags=1, individual=2) for keep in [True, False]: - tss, _ = self.do_simplify( + tss, _ = do_simplify( tables.tree_sequence(), filter_individuals=True, keep_unary=keep ) assert tss.num_nodes == 2 @@ -5546,13 +5581,13 @@ def test_removed_node_individual_filter(self): assert tss.node(0).individual == 0 assert tss.node(1).individual == 1 - tss, _ = self.do_simplify( + tss, _ = do_simplify( tables.tree_sequence(), filter_individuals=False, keep_unary=keep ) assert tss.tables.individuals == tables.individuals def verify_simplify_haplotypes(self, ts, samples, keep_unary=False): - sub_ts, node_map = self.do_simplify( + sub_ts, node_map = do_simplify( ts, samples, filter_sites=False, keep_unary=keep_unary ) assert ts.num_sites == sub_ts.num_sites @@ -5643,7 +5678,7 @@ def verify(self, ts): def verify_keep_input_roots(self, ts, samples): ts = tsutil.insert_unique_metadata(ts, ["individuals"]) - ts_with_roots, node_map = self.do_simplify( + ts_with_roots, node_map = do_simplify( ts, samples, keep_input_roots=True, filter_sites=False, compare_lib=True ) new_to_input_map = { @@ -5784,6 +5819,159 @@ def test_many_trees_recurrent_mutations(self): self.verify_keep_input_roots(ts, samples) +class TestSimplifyFilterNodes: + """ + Tests simplify when nodes are kept in the ts with filter_nodes=False + """ + + def reverse_node_indexes(self, ts): + tables = ts.dump_tables() + nodes = tables.nodes + edges 
= tables.edges + mutations = tables.mutations + nodes.replace_with(nodes[::-1]) + edges.parent = ts.num_nodes - edges.parent - 1 + edges.child = ts.num_nodes - edges.child - 1 + mutations.node = ts.num_nodes - mutations.node - 1 + tables.sort() + return tables.tree_sequence() + + def verify_nodes_unchanged(self, ts_in, resample_size=None, **kwargs): + if resample_size is None: + samples = None + else: + np.random.seed(42) + samples = np.sort( + np.random.choice(ts_in.num_nodes, resample_size, replace=False) + ) + + for ts in (ts_in, self.reverse_node_indexes(ts_in)): + filtered, n_map = do_simplify( + ts, samples=samples, filter_nodes=False, compare_lib=False, **kwargs + ) + assert np.array_equal(n_map, np.arange(ts.num_nodes, dtype=n_map.dtype)) + referenced_nodes = set(filtered.samples()) + referenced_nodes.update(filtered.edges_parent) + referenced_nodes.update(filtered.edges_child) + for n1, n2 in zip(ts.nodes(), filtered.nodes()): + # Ignore the tskit.NODE_IS_SAMPLE flag which can be changed by simplify + if n2.id in referenced_nodes: + assert n_map[n2.id] == tskit.NULL + else: + n1 = n1.replace(flags=n1.flags | tskit.NODE_IS_SAMPLE) + n2 = n2.replace(flags=n2.flags | tskit.NODE_IS_SAMPLE) + assert n1 == n2 + + # Check that edges are identical to the normal simplify(), + # with the normal "simplify" having altered IDs + simplified, node_map = ts.simplify( + samples=samples, map_nodes=True, **kwargs + ) + simplified_edges = {e for e in simplified.tables.edges} + filtered_edges = { + e.replace(parent=node_map[e.parent], child=node_map[e.child]) + for e in filtered.tables.edges + } + assert filtered_edges == simplified_edges + + def test_empty(self): + ts = tskit.TableCollection(1).tree_sequence() + self.verify_nodes_unchanged(ts) + + def test_all_samples(self): + ts = tskit.Tree.generate_comb(5).tree_sequence + tables = ts.dump_tables() + flags = tables.nodes.flags + flags |= tskit.NODE_IS_SAMPLE + tables.nodes.flags = flags + ts = tables.tree_sequence() + 
assert ts.num_samples == ts.num_nodes + self.verify_nodes_unchanged(ts) + + @pytest.mark.parametrize("resample_size", [None, 4]) + def test_no_topology(self, resample_size): + ts = tskit.Tree.generate_comb(5).tree_sequence + ts = ts.keep_intervals([], simplify=False) + assert ts.num_nodes > 5 # has unreferenced nodes + self.verify_nodes_unchanged(ts, resample_size=resample_size) + + @pytest.mark.parametrize("resample_size", [None, 2]) + def test_stick_tree(self, resample_size): + ts = tskit.Tree.generate_comb(2).tree_sequence + ts = ts.simplify([0], keep_unary=True) + assert ts.first().parent(0) != tskit.NULL + self.verify_nodes_unchanged(ts, resample_size=resample_size) + + # switch to an internal sample + tables = ts.dump_tables() + flags = tables.nodes.flags + flags[0] = 0 + flags[1] = tskit.NODE_IS_SAMPLE + tables.nodes.flags = flags + self.verify_nodes_unchanged(tables.tree_sequence(), resample_size=resample_size) + + @pytest.mark.parametrize("resample_size", [None, 4]) + def test_internal_samples(self, resample_size): + ts = tskit.Tree.generate_comb(4).tree_sequence + tables = ts.dump_tables() + flags = tables.nodes.flags + flags ^= tskit.NODE_IS_SAMPLE + tables.nodes.flags = flags + ts = tables.tree_sequence() + assert np.all(ts.samples() >= ts.num_samples) + self.verify_nodes_unchanged(ts, resample_size=resample_size) + + @pytest.mark.parametrize("resample_size", [None, 4]) + def test_blank_flanks(self, resample_size): + ts = tskit.Tree.generate_comb(4).tree_sequence + ts = ts.keep_intervals([[0.25, 0.75]], simplify=False) + self.verify_nodes_unchanged(ts, resample_size=resample_size) + + @pytest.mark.parametrize("resample_size", [None, 4]) + def test_multiroot(self, resample_size): + ts = tskit.Tree.generate_balanced(6).tree_sequence + ts = ts.decapitate(2.5) + self.verify_nodes_unchanged(ts, resample_size=resample_size) + + @pytest.mark.parametrize("resample_size", [None, 10]) + def test_with_metadata(self, ts_fixture_for_simplify, resample_size): + 
assert ts_fixture_for_simplify.num_nodes > 10 + self.verify_nodes_unchanged( + ts_fixture_for_simplify, resample_size=resample_size + ) + + @pytest.mark.parametrize("resample_size", [None, 7]) + def test_complex_ts_with_unary(self, resample_size): + ts = msprime.sim_ancestry( + 3, + sequence_length=10, + recombination_rate=1, + record_full_arg=True, + random_seed=123, + ) + assert ts.num_trees > 2 + ts = msprime.sim_mutations(ts, rate=1, random_seed=123) + # Add some unreferenced nodes + tables = ts.dump_tables() + tables.nodes.add_row(flags=0) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE) + ts = tables.tree_sequence() + self.verify_nodes_unchanged(ts, resample_size=resample_size) + + def test_keeping_unary(self): + # Test interaction with keeping unary nodes + n_samples = 6 + ts = tskit.Tree.generate_comb(n_samples).tree_sequence + num_nodes = ts.num_nodes + reduced_n_samples = [2, n_samples - 1] # last sample is most deeply nested + ts_with_unary = ts.simplify(reduced_n_samples, keep_unary=True) + assert ts_with_unary.num_nodes == num_nodes - n_samples + len(reduced_n_samples) + tree = ts_with_unary.first() + assert any([tree.num_children(u) == 1 for u in tree.nodes()]) + self.verify_nodes_unchanged(ts_with_unary, keep_unary=True) + self.verify_nodes_unchanged(ts_with_unary, keep_unary=False) + + class TestMapToAncestors: """ Tests the AncestorMap class. 
diff --git a/python/tests/tsutil.py b/python/tests/tsutil.py index 7f70e29a1f..6f3b080ce5 100644 --- a/python/tests/tsutil.py +++ b/python/tests/tsutil.py @@ -1932,9 +1932,12 @@ def all_trees_ts(n): return tables.tree_sequence() -def all_fields_ts(): +def all_fields_ts(edge_metadata=True, migrations=True): """ - A tree sequence with data in all fields + A tree sequence with data in all fields (except edge metadata is not set if + edge_metadata is False and migrations are not defined if migrations is False + (this is needed to test simplify, which doesn't allow either) + """ demography = msprime.Demography() demography.add_population(name="A", initial_size=10_000) @@ -1949,7 +1952,7 @@ def all_fields_ts(): sequence_length=5, random_seed=42, recombination_rate=1, - record_migrations=True, + record_migrations=migrations, record_provenance=True, ) ts = msprime.sim_mutations(ts, rate=0.001, random_seed=42) @@ -1973,21 +1976,27 @@ def all_fields_ts(): population=i % len(tables.populations), ) ) - tables.migrations.add_row(left=0, right=1, node=21, source=1, dest=3, time=1001) + if migrations: + tables.migrations.add_row(left=0, right=1, node=21, source=1, dest=3, time=1001) # Add metadata for name, table in tables.table_name_map.items(): - if name != "provenances": - table.metadata_schema = tskit.MetadataSchema.permissive_json() - metadatas = [f'{{"foo":"n_{name}_{u}"}}' for u in range(len(table))] - metadata, metadata_offset = tskit.pack_strings(metadatas) - table.set_columns( - **{ - **table.asdict(), - "metadata": metadata, - "metadata_offset": metadata_offset, - } - ) + if name == "provenances": + continue + if name == "migrations" and not migrations: + continue + if name == "edges" and not edge_metadata: + continue + table.metadata_schema = tskit.MetadataSchema.permissive_json() + metadatas = [f'{{"foo":"n_{name}_{u}"}}' for u in range(len(table))] + metadata, metadata_offset = tskit.pack_strings(metadatas) + table.set_columns( + **{ + **table.asdict(), + 
"metadata": metadata, + "metadata_offset": metadata_offset, + } + ) tables.metadata_schema = tskit.MetadataSchema.permissive_json() tables.metadata = "Test metadata" tables.time_units = "Test time units" diff --git a/python/tskit/tables.py b/python/tskit/tables.py index 443b3a5ba8..483186ffc6 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -3348,6 +3348,7 @@ def simplify( filter_populations=None, filter_individuals=None, filter_sites=None, + filter_nodes=None, keep_unary=False, keep_unary_in_individuals=None, keep_input_roots=False, @@ -3357,9 +3358,9 @@ def simplify( """ Simplifies the tables in place to retain only the information necessary to reconstruct the tree sequence describing the given ``samples``. - This will change the ID of the nodes, so that the node - ``samples[k]`` will have ID ``k`` in the result. The resulting - NodeTable will have only the first ``len(samples)`` nodes marked + If ``filter_nodes`` is True, this can change the ID of the nodes, so + that the node ``samples[k]`` will have ID ``k`` in the result, resulting + in a NodeTable where only the first ``len(samples)`` nodes are marked as samples. The mapping from node IDs in the current set of tables to their equivalent values in the simplified tables is also returned as a numpy array. If an array ``a`` is returned by this function and ``u`` @@ -3399,6 +3400,11 @@ def simplify( not referenced by mutations after simplification; new site IDs are allocated sequentially from zero. If False, the site table will not be altered in any way. (Default: None, treated as True) + :param bool filter_nodes: If True, remove any nodes that are + not referenced by edges after simplification. If False, the only + potential change to the node table may be to change the node flags + (if ``samples`` is specified and different from the existing samples). + (Default: None, treated as True) :param bool keep_unary: If True, preserve unary nodes (i.e. 
nodes with exactly one child) that exist on the path from samples to root. (Default: False) @@ -3440,6 +3446,8 @@ def simplify( filter_individuals = True if filter_sites is None: filter_sites = True + if filter_nodes is None: + filter_nodes = True if keep_unary_in_individuals is None: keep_unary_in_individuals = False @@ -3448,6 +3456,7 @@ def simplify( filter_sites=filter_sites, filter_individuals=filter_individuals, filter_populations=filter_populations, + filter_nodes=filter_nodes, reduce_to_site_topology=reduce_to_site_topology, keep_unary=keep_unary, keep_unary_in_individuals=keep_unary_in_individuals, diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 49552545c0..ba1d109e03 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -6491,6 +6491,7 @@ def simplify( filter_populations=None, filter_individuals=None, filter_sites=None, + filter_nodes=None, keep_unary=False, keep_unary_in_individuals=None, keep_input_roots=False, @@ -6505,11 +6506,14 @@ def simplify( original tree sequence, or :data:`tskit.NULL` (-1) if ``u`` is no longer present in the simplified tree sequence. - In the returned tree sequence, the node with ID ``0`` corresponds to - ``samples[0]``, node ``1`` corresponds to ``samples[1]`` etc., and all - the passed-in nodes are flagged as samples. The remaining node IDs in - the returned tree sequence are allocated sequentially in time order - and are not flagged as samples. + In the returned tree sequence the only nodes flagged as samples are those + passed as ``samples``: all others are not flagged as samples. + If ``filter_nodes`` is not False, nodes in the returned tree sequence are + also reordered such that the node with ID ``0`` corresponds to ``samples[0]``, + node ``1`` corresponds to ``samples[1]`` etc., and the remaining node IDs + are allocated sequentially in time order. Alternatively, if ``filter_nodes`` + is False, the node order is not changed, and the order of IDs passed to + ``samples`` is irrelevant. 
If you wish to simplify a set of tables that do not satisfy all requirements for building a TreeSequence, then use @@ -6525,12 +6529,12 @@ def simplify( (up to node ID remapping) to the topology of the corresponding tree in the input tree sequence. - If ``filter_populations``, ``filter_individuals`` or ``filter_sites`` is - True, any of the corresponding objects that are not referenced elsewhere - are filtered out. As this is the default behaviour, it is important to - realise IDs for these objects may change through simplification. By setting - these parameters to False, however, the corresponding tables can be preserved - without changes. + If ``filter_populations``, ``filter_individuals``, ``filter_sites``, or + ``filter_nodes`` is True, any of the corresponding objects that are not + referenced elsewhere are filtered out. As this is the default behaviour, + it is important to realise IDs for these objects may change through + simplification. By setting these parameters to False, however, the + corresponding tables can be preserved without changes. :param list[int] samples: A list of node IDs to retain as samples. They need not be nodes marked as samples in the original tree sequence, but @@ -6556,6 +6560,11 @@ def simplify( not referenced by mutations after simplification; new site IDs are allocated sequentially from zero. If False, the site table will not be altered in any way. (Default: None, treated as True) + :param bool filter_nodes: If True, remove any nodes that are + not referenced by edges after simplification. If False, the only + potential change to the node table may be to change the node flags + (if ``samples`` is specified and different from the existing samples). + (Default: None, treated as True) :param bool keep_unary: If True, preserve unary nodes (i.e., nodes with exactly one child) that exist on the path from samples to root. 
(Default: False) @@ -6587,6 +6596,7 @@ def simplify( filter_populations=filter_populations, filter_individuals=filter_individuals, filter_sites=filter_sites, + filter_nodes=filter_nodes, keep_unary=keep_unary, keep_unary_in_individuals=keep_unary_in_individuals, keep_input_roots=keep_input_roots, From 2b61dfdddff77964384267fe1453f4f9a352912e Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Wed, 2 Nov 2022 10:05:23 +0000 Subject: [PATCH 08/84] Update simplify semantics and refactor Revert to old implementation Simple tset Update changelog and add C tests Done? Low level tests More tests Fix error-handling error in simplifier_init We were clearing the input tables before checking sample errors Fixup broken tests Add test for mutations Don't clear the node table when not filtering nodes Refactor Modernise simplify test Fix provenance bug Removed unused samples member Implement filter-populations with no-touch semantics updates Refactor finalise references path Finished no-touch semantics on the non-filter casese Remove unused simplify_t struct member Make simplify thread-safe in no-filter case Update changelog --- c/CHANGELOG.rst | 7 + c/tests/test_tables.c | 22 ++ c/tests/test_trees.c | 67 +++++ c/tskit/tables.c | 504 +++++++++++++++++++-------------- c/tskit/tables.h | 43 ++- python/tests/simplify.py | 46 +-- python/tests/test_highlevel.py | 366 ++++++++++++------------ python/tests/test_lowlevel.py | 20 ++ python/tests/test_tables.py | 4 +- python/tests/test_topology.py | 159 +++++++++-- python/tskit/tables.py | 29 +- python/tskit/trees.py | 16 +- 12 files changed, 832 insertions(+), 451 deletions(-) diff --git a/c/CHANGELOG.rst b/c/CHANGELOG.rst index 461d3d788a..7cf7d42507 100644 --- a/c/CHANGELOG.rst +++ b/c/CHANGELOG.rst @@ -10,6 +10,13 @@ ``tsk_treeseq_get_min_time`` and ``tsk_treeseq_get_max_time``, respectively. 
(:user:`szhan`, :pr:`2612`, :issue:`2271`) +- Add the `TSK_SIMPLIFY_NO_FILTER_NODES` option to simplify to allow unreferenced + nodes be kept in the output (:user:`jeromekelleher`, :user:`hyanwong`, + :issue:`2606`, :pr:`2619`). + +- Guarantee that unfiltered tables are not written to unnecessarily + during simplify (:user:`jeromekelleher` :pr:`2619`). + -------------------- [1.1.1] - 2022-07-29 -------------------- diff --git a/c/tests/test_tables.c b/c/tests/test_tables.c index 8bd7ac606e..1b99ac6fe3 100644 --- a/c/tests/test_tables.c +++ b/c/tests/test_tables.c @@ -345,10 +345,32 @@ test_table_collection_simplify_errors(void) tsk_id_t samples[] = { 0, 1 }; tsk_id_t ret_id; const char *individuals = "1 0.25 -2\n"; + ret = tsk_table_collection_init(&tables, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); tables.sequence_length = 1; + ret_id = tsk_node_table_add_row(&tables.nodes, 0, 0, TSK_NULL, TSK_NULL, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_node_table_add_row(&tables.nodes, 0, 0, TSK_NULL, TSK_NULL, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + + /* Bad samples */ + samples[0] = -1; + ret = tsk_table_collection_simplify(&tables, samples, 2, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + samples[0] = 10; + ret = tsk_table_collection_simplify(&tables, samples, 2, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + samples[0] = 0; + + /* Duplicate samples */ + samples[0] = 0; + samples[1] = 0; + ret = tsk_table_collection_simplify(&tables, samples, 2, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_DUPLICATE_SAMPLE); + samples[0] = 0; + ret_id = tsk_site_table_add_row(&tables.sites, 0, "A", 1, NULL, 0); CU_ASSERT_FATAL(ret_id >= 0); ret_id = tsk_site_table_add_row(&tables.sites, 0, "A", 1, NULL, 0); diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c index 18b904171b..bb70d273d6 100644 --- a/c/tests/test_trees.c +++ b/c/tests/test_trees.c @@ -3257,6 +3257,72 @@ test_simplest_individual_filter(void) 
tsk_table_collection_free(&tables); } +static void +test_simplest_no_node_filter(void) +{ + const char *nodes = "1 0 0\n" + "1 0 0\n" + "0 1 0\n" + "0 1 0"; /* unreferenced node */ + const char *edges = "0 1 2 0,1\n"; + tsk_treeseq_t ts, simplified; + tsk_id_t sample_ids[] = { 0, 1 }; + tsk_id_t node_map[] = { -1, -1, -1, -1 }; + tsk_id_t j; + int ret; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0); + + ret = tsk_treeseq_simplify( + &ts, NULL, 0, TSK_SIMPLIFY_NO_FILTER_NODES, &simplified, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_table_collection_equals(ts.tables, simplified.tables, 0)); + tsk_treeseq_free(&simplified); + + ret = tsk_treeseq_simplify( + &ts, sample_ids, 2, TSK_SIMPLIFY_NO_FILTER_NODES, &simplified, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_table_collection_equals(ts.tables, simplified.tables, 0)); + tsk_treeseq_free(&simplified); + + /* Reversing sample order makes no difference */ + sample_ids[0] = 1; + sample_ids[1] = 0; + ret = tsk_treeseq_simplify( + &ts, sample_ids, 2, TSK_SIMPLIFY_NO_FILTER_NODES, &simplified, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_table_collection_equals(ts.tables, simplified.tables, 0)); + tsk_treeseq_free(&simplified); + + ret = tsk_treeseq_simplify( + &ts, sample_ids, 1, TSK_SIMPLIFY_NO_FILTER_NODES, &simplified, node_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&simplified), 4); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_edges(&simplified), 0); + for (j = 0; j < 4; j++) { + CU_ASSERT_EQUAL(node_map[j], j); + } + tsk_treeseq_free(&simplified); + + ret = tsk_treeseq_simplify(&ts, sample_ids, 1, + TSK_SIMPLIFY_NO_FILTER_NODES | TSK_SIMPLIFY_KEEP_INPUT_ROOTS + | TSK_SIMPLIFY_KEEP_UNARY, + &simplified, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&simplified), 4); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_edges(&simplified), 1); + tsk_treeseq_free(&simplified); + + 
sample_ids[0] = 0; + sample_ids[1] = 0; + ret = tsk_treeseq_simplify( + &ts, sample_ids, 2, TSK_SIMPLIFY_NO_FILTER_NODES, &simplified, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_DUPLICATE_SAMPLE); + tsk_treeseq_free(&simplified); + + tsk_treeseq_free(&ts); +} + static void test_simplest_map_mutations(void) { @@ -8026,6 +8092,7 @@ main(int argc, char **argv) { "test_simplest_simplify_defragment", test_simplest_simplify_defragment }, { "test_simplest_population_filter", test_simplest_population_filter }, { "test_simplest_individual_filter", test_simplest_individual_filter }, + { "test_simplest_no_node_filter", test_simplest_no_node_filter }, { "test_simplest_map_mutations", test_simplest_map_mutations }, { "test_simplest_nonbinary_map_mutations", test_simplest_nonbinary_map_mutations }, diff --git a/c/tskit/tables.c b/c/tskit/tables.c index 43fd828d2a..ff50e6398d 100644 --- a/c/tskit/tables.c +++ b/c/tskit/tables.c @@ -7159,7 +7159,6 @@ typedef struct { } segment_overlapper_t; typedef struct { - tsk_id_t *samples; tsk_size_t num_samples; tsk_flags_t options; tsk_table_collection_t *tables; @@ -7168,6 +7167,7 @@ typedef struct { /* State for topology */ tsk_segment_t **ancestor_map_head; tsk_segment_t **ancestor_map_tail; + /* Mapping of input node IDs to output node IDs. */ tsk_id_t *node_id_map; bool *is_sample; /* Segments for a particular parent that are processed together */ @@ -7185,8 +7185,6 @@ typedef struct { tsk_size_t num_buffered_children; /* For each mutation, map its output node. */ tsk_id_t *mutation_node_map; - /* Map of input mutation IDs to output mutation IDs. */ - tsk_id_t *mutation_id_map; /* Map of input nodes to the list of input mutation IDs */ mutation_id_list_t **node_mutation_list_map_head; mutation_id_list_t **node_mutation_list_map_tail; @@ -8807,7 +8805,7 @@ simplifier_alloc_interval_list(simplifier_t *self, double left, double right) /* Add a new node to the output node table corresponding to the specified input id. 
* Returns the new ID. */ static tsk_id_t TSK_WARN_UNUSED -simplifier_record_node(simplifier_t *self, tsk_id_t input_id, bool is_sample) +simplifier_record_node(simplifier_t *self, tsk_id_t input_id) { tsk_node_t node; tsk_flags_t flags; @@ -8815,7 +8813,7 @@ simplifier_record_node(simplifier_t *self, tsk_id_t input_id, bool is_sample) tsk_node_table_get_row_unsafe(&self->input_tables.nodes, (tsk_id_t) input_id, &node); /* Zero out the sample bit */ flags = node.flags & (tsk_flags_t) ~TSK_NODE_IS_SAMPLE; - if (is_sample) { + if (self->is_sample[input_id]) { flags |= TSK_NODE_IS_SAMPLE; } self->node_id_map[input_id] = (tsk_id_t) self->tables->nodes.num_rows; @@ -8878,7 +8876,7 @@ simplifier_init_position_lookup(simplifier_t *self) goto out; } self->position_lookup[0] = 0; - self->position_lookup[num_sites + 1] = self->tables->sequence_length; + self->position_lookup[num_sites + 1] = self->input_tables.sequence_length; tsk_memcpy(self->position_lookup + 1, self->input_tables.sites.position, num_sites * sizeof(double)); out: @@ -8922,7 +8920,7 @@ simplifier_record_edge(simplifier_t *self, double left, double right, tsk_id_t c interval_list_t *tail, *x; bool skip; - if (!!(self->options & TSK_SIMPLIFY_REDUCE_TO_SITE_TOPOLOGY)) { + if (self->options & TSK_SIMPLIFY_REDUCE_TO_SITE_TOPOLOGY) { skip = simplifier_map_reduced_coordinates(self, &left, &right); /* NOTE: we exit early here when reduce_coordindates has told us to * skip this edge, as it is not visible in the reduced tree sequence */ @@ -8968,8 +8966,6 @@ simplifier_init_sites(simplifier_t *self) mutation_id_list_t *list_node; tsk_size_t j; - self->mutation_id_map - = tsk_calloc(self->input_tables.mutations.num_rows, sizeof(tsk_id_t)); self->mutation_node_map = tsk_calloc(self->input_tables.mutations.num_rows, sizeof(tsk_id_t)); self->node_mutation_list_mem @@ -8978,15 +8974,12 @@ simplifier_init_sites(simplifier_t *self) = tsk_calloc(self->input_tables.nodes.num_rows, sizeof(mutation_id_list_t *)); 
self->node_mutation_list_map_tail = tsk_calloc(self->input_tables.nodes.num_rows, sizeof(mutation_id_list_t *)); - if (self->mutation_id_map == NULL || self->mutation_node_map == NULL - || self->node_mutation_list_mem == NULL + if (self->mutation_node_map == NULL || self->node_mutation_list_mem == NULL || self->node_mutation_list_map_head == NULL || self->node_mutation_list_map_tail == NULL) { ret = TSK_ERR_NO_MEMORY; goto out; } - tsk_memset(self->mutation_id_map, 0xff, - self->input_tables.mutations.num_rows * sizeof(tsk_id_t)); tsk_memset(self->mutation_node_map, 0xff, self->input_tables.mutations.num_rows * sizeof(tsk_id_t)); @@ -9060,58 +9053,93 @@ simplifier_add_ancestry( return ret; } +/* Sets up the internal working copies of the various tables, as needed + * depending on the specified options. */ +static int +simplifier_init_tables(simplifier_t *self) +{ + int ret; + bool filter_nodes = !(self->options & TSK_SIMPLIFY_NO_FILTER_NODES); + bool filter_populations = self->options & TSK_SIMPLIFY_FILTER_POPULATIONS; + bool filter_individuals = self->options & TSK_SIMPLIFY_FILTER_INDIVIDUALS; + bool filter_sites = self->options & TSK_SIMPLIFY_FILTER_SITES; + tsk_bookmark_t rows_to_retain; + + /* NOTE: this is a bit inefficient here as we're taking copies of + * the tables even in the no-filter case where the original tables + * won't be touched (beyond references to external tables that may + * need updating). Future versions may do something a bit more + * complicated like temporarily stealing the pointers to the + * underlying column memory in these tables, and then being careful + * not to free the table at the end. 
+ */ + ret = tsk_table_collection_copy(self->tables, &self->input_tables, 0); + if (ret != 0) { + goto out; + } + memset(&rows_to_retain, 0, sizeof(rows_to_retain)); + rows_to_retain.provenances = self->tables->provenances.num_rows; + if (!filter_nodes) { + rows_to_retain.nodes = self->tables->nodes.num_rows; + } + if (!filter_populations) { + rows_to_retain.populations = self->tables->populations.num_rows; + } + if (!filter_individuals) { + rows_to_retain.individuals = self->tables->individuals.num_rows; + } + if (!filter_sites) { + rows_to_retain.sites = self->tables->sites.num_rows; + } + + ret = tsk_table_collection_truncate(self->tables, &rows_to_retain); + if (ret != 0) { + goto out; + } +out: + return ret; +} + static int simplifier_init_nodes(simplifier_t *self, const tsk_id_t *samples) { int ret = 0; tsk_id_t node_id; tsk_size_t j; - tsk_size_t num_nodes = self->input_tables.nodes.num_rows; + const tsk_size_t num_nodes = self->input_tables.nodes.num_rows; bool filter_nodes = !(self->options & TSK_SIMPLIFY_NO_FILTER_NODES); - bool is_sample; - - for (j = 0; j < self->num_samples; j++) { - if (samples[j] < 0 || samples[j] > (tsk_id_t) num_nodes) { - ret = TSK_ERR_NODE_OUT_OF_BOUNDS; - goto out; - } - if (self->is_sample[samples[j]]) { - ret = TSK_ERR_DUPLICATE_SAMPLE; - goto out; - } - self->is_sample[samples[j]] = true; - } + tsk_flags_t *node_flags = self->tables->nodes.flags; + tsk_id_t *node_id_map = self->node_id_map; if (filter_nodes) { - /* Go through the samples to check for errors. */ + tsk_bug_assert(self->tables->nodes.num_rows == 0); + /* The node table has been cleared. Add nodes for the samples. 
*/ for (j = 0; j < self->num_samples; j++) { - node_id = simplifier_record_node(self, samples[j], true); + node_id = simplifier_record_node(self, samples[j]); if (node_id < 0) { ret = (int) node_id; goto out; } - ret = simplifier_add_ancestry( - self, samples[j], 0, self->tables->sequence_length, node_id); - if (ret != 0) { - goto out; - } } } else { - /* record all the nodes, but only save ancestry for those in the sample */ + tsk_bug_assert(self->tables->nodes.num_rows == num_nodes); + /* The node table has not been changed */ for (j = 0; j < num_nodes; j++) { - is_sample = self->is_sample[j]; - node_id = simplifier_record_node(self, (tsk_id_t) j, is_sample); - if (node_id < 0) { - ret = (int) node_id; - goto out; - } - if (is_sample) { - ret = simplifier_add_ancestry( - self, node_id, 0, self->tables->sequence_length, node_id); - if (ret != 0) { - goto out; - } + /* Reset the sample flags */ + node_flags[j] &= (tsk_flags_t) ~TSK_NODE_IS_SAMPLE; + if (self->is_sample[j]) { + node_flags[j] |= TSK_NODE_IS_SAMPLE; } + node_id_map[j] = (tsk_id_t) j; + } + } + /* Add the initial ancestry */ + for (j = 0; j < self->num_samples; j++) { + node_id = samples[j]; + ret = simplifier_add_ancestry(self, node_id, 0, + self->input_tables.sequence_length, self->node_id_map[node_id]); + if (ret != 0) { + goto out; } } out: @@ -9123,6 +9151,7 @@ simplifier_init(simplifier_t *self, const tsk_id_t *samples, tsk_size_t num_samp tsk_table_collection_t *tables, tsk_flags_t options) { int ret = 0; + tsk_size_t j; tsk_id_t ret_id; tsk_size_t num_nodes; @@ -9144,19 +9173,6 @@ simplifier_init(simplifier_t *self, const tsk_id_t *samples, tsk_size_t num_samp goto out; } - ret = tsk_table_collection_copy(self->tables, &self->input_tables, 0); - if (ret != 0) { - goto out; - } - - /* Take a copy of the input samples */ - self->samples = tsk_malloc(num_samples * sizeof(tsk_id_t)); - if (self->samples == NULL) { - ret = TSK_ERR_NO_MEMORY; - goto out; - } - tsk_memcpy(self->samples, samples, 
num_samples * sizeof(tsk_id_t)); - /* Allocate the heaps used for small objects-> Assuming 8K is a good chunk size */ ret = tsk_blkalloc_init(&self->segment_heap, 8192); @@ -9190,12 +9206,25 @@ simplifier_init(simplifier_t *self, const tsk_id_t *samples, tsk_size_t num_samp ret = TSK_ERR_NO_MEMORY; goto out; } - ret = tsk_table_collection_clear(self->tables, 0); + + /* Go through the samples to check for errors before we clear the tables. */ + for (j = 0; j < self->num_samples; j++) { + if (samples[j] < 0 || samples[j] >= (tsk_id_t) num_nodes) { + ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + goto out; + } + if (self->is_sample[samples[j]]) { + ret = TSK_ERR_DUPLICATE_SAMPLE; + goto out; + } + self->is_sample[samples[j]] = true; + } + tsk_memset(self->node_id_map, 0xff, num_nodes * sizeof(tsk_id_t)); + + ret = simplifier_init_tables(self); if (ret != 0) { goto out; } - tsk_memset( - self->node_id_map, 0xff, self->input_tables.nodes.num_rows * sizeof(tsk_id_t)); ret = simplifier_init_sites(self); if (ret != 0) { goto out; @@ -9204,12 +9233,13 @@ simplifier_init(simplifier_t *self, const tsk_id_t *samples, tsk_size_t num_samp if (ret != 0) { goto out; } - if (!!(self->options & TSK_SIMPLIFY_REDUCE_TO_SITE_TOPOLOGY)) { + if (self->options & TSK_SIMPLIFY_REDUCE_TO_SITE_TOPOLOGY) { ret = simplifier_init_position_lookup(self); if (ret != 0) { goto out; } } + self->edge_sort_offset = TSK_NULL; out: return ret; @@ -9222,7 +9252,6 @@ simplifier_free(simplifier_t *self) tsk_blkalloc_free(&self->segment_heap); tsk_blkalloc_free(&self->interval_list_heap); segment_overlapper_free(&self->segment_overlapper); - tsk_safe_free(self->samples); tsk_safe_free(self->ancestor_map_head); tsk_safe_free(self->ancestor_map_tail); tsk_safe_free(self->child_edge_map_head); @@ -9230,7 +9259,6 @@ simplifier_free(simplifier_t *self) tsk_safe_free(self->node_id_map); tsk_safe_free(self->segment_queue); tsk_safe_free(self->is_sample); - tsk_safe_free(self->mutation_id_map); 
tsk_safe_free(self->mutation_node_map); tsk_safe_free(self->node_mutation_list_mem); tsk_safe_free(self->node_mutation_list_map_head); @@ -9278,11 +9306,10 @@ simplifier_merge_ancestors(simplifier_t *self, tsk_id_t input_id) double left, right, prev_right; tsk_id_t ancestry_node; tsk_id_t output_id = self->node_id_map[input_id]; - bool is_sample = self->is_sample[input_id]; - /* bool is_sample = output_id != TSK_NULL; */ bool filter_nodes = !(self->options & TSK_SIMPLIFY_NO_FILTER_NODES); - bool keep_unary = !!(self->options & TSK_SIMPLIFY_KEEP_UNARY); + bool keep_unary = self->options & TSK_SIMPLIFY_KEEP_UNARY; + if ((self->options & TSK_SIMPLIFY_KEEP_UNARY_IN_INDIVIDUALS) && (self->input_tables.nodes.individual[input_id] != TSK_NULL)) { keep_unary = true; @@ -9317,7 +9344,7 @@ simplifier_merge_ancestors(simplifier_t *self, tsk_id_t input_id) ancestry_node = output_id; } else if (keep_unary) { if (output_id == TSK_NULL) { - output_id = simplifier_record_node(self, input_id, false); + output_id = simplifier_record_node(self, input_id); } ret = simplifier_record_edge(self, left, right, ancestry_node); if (ret != 0) { @@ -9326,7 +9353,7 @@ simplifier_merge_ancestors(simplifier_t *self, tsk_id_t input_id) } } else { if (output_id == TSK_NULL) { - output_id = simplifier_record_node(self, input_id, false); + output_id = simplifier_record_node(self, input_id); if (output_id < 0) { ret = (int) output_id; goto out; @@ -9479,133 +9506,60 @@ simplifier_process_parent_edges( } static int TSK_WARN_UNUSED -simplifier_output_sites(simplifier_t *self) +simplifier_finalise_site_references( + simplifier_t *self, const bool *site_referenced, tsk_id_t *site_id_map) { int ret = 0; tsk_id_t ret_id; - tsk_id_t input_site; - tsk_id_t input_mutation, mapped_parent, site_start, site_end; - tsk_id_t num_input_sites = (tsk_id_t) self->input_tables.sites.num_rows; - tsk_id_t num_input_mutations = (tsk_id_t) self->input_tables.mutations.num_rows; - tsk_id_t num_output_mutations, 
num_output_site_mutations; - tsk_id_t mapped_node; - bool keep_site; - bool filter_sites = !!(self->options & TSK_SIMPLIFY_FILTER_SITES); + tsk_size_t j; tsk_site_t site; - tsk_mutation_t mutation; - - input_mutation = 0; - num_output_mutations = 0; - for (input_site = 0; input_site < num_input_sites; input_site++) { - tsk_site_table_get_row_unsafe( - &self->input_tables.sites, (tsk_id_t) input_site, &site); - site_start = input_mutation; - num_output_site_mutations = 0; - while (input_mutation < num_input_mutations - && self->input_tables.mutations.site[input_mutation] == site.id) { - mapped_node = self->mutation_node_map[input_mutation]; - if (mapped_node != TSK_NULL) { - self->mutation_id_map[input_mutation] = num_output_mutations; - num_output_mutations++; - num_output_site_mutations++; - } - input_mutation++; - } - site_end = input_mutation; - - keep_site = true; - if (filter_sites && num_output_site_mutations == 0) { - keep_site = false; - } - if (keep_site) { - for (input_mutation = site_start; input_mutation < site_end; - input_mutation++) { - if (self->mutation_id_map[input_mutation] != TSK_NULL) { - tsk_bug_assert( - self->tables->mutations.num_rows - == (tsk_size_t) self->mutation_id_map[input_mutation]); - mapped_node = self->mutation_node_map[input_mutation]; - tsk_bug_assert(mapped_node != TSK_NULL); - mapped_parent = self->input_tables.mutations.parent[input_mutation]; - if (mapped_parent != TSK_NULL) { - mapped_parent = self->mutation_id_map[mapped_parent]; - } - tsk_mutation_table_get_row_unsafe(&self->input_tables.mutations, - (tsk_id_t) input_mutation, &mutation); - ret_id = tsk_mutation_table_add_row(&self->tables->mutations, - (tsk_id_t) self->tables->sites.num_rows, mapped_node, - mapped_parent, mutation.time, mutation.derived_state, - mutation.derived_state_length, mutation.metadata, - mutation.metadata_length); - if (ret_id < 0) { - ret = (int) ret_id; - goto out; - } + const tsk_size_t num_sites = self->input_tables.sites.num_rows; + + if 
(self->options & TSK_SIMPLIFY_FILTER_SITES) { + for (j = 0; j < num_sites; j++) { + tsk_site_table_get_row_unsafe( + &self->input_tables.sites, (tsk_id_t) j, &site); + site_id_map[j] = TSK_NULL; + if (site_referenced[j]) { + ret_id = tsk_site_table_add_row(&self->tables->sites, site.position, + site.ancestral_state, site.ancestral_state_length, site.metadata, + site.metadata_length); + if (ret_id < 0) { + ret = (int) ret_id; + goto out; } - } - ret_id = tsk_site_table_add_row(&self->tables->sites, site.position, - site.ancestral_state, site.ancestral_state_length, site.metadata, - site.metadata_length); - if (ret_id < 0) { - ret = (int) ret_id; - goto out; + site_id_map[j] = ret_id; } } - tsk_bug_assert( - num_output_mutations == (tsk_id_t) self->tables->mutations.num_rows); - input_mutation = site_end; + } else { + tsk_bug_assert(self->tables->sites.num_rows == num_sites); + for (j = 0; j < num_sites; j++) { + site_id_map[j] = (tsk_id_t) j; + } } - tsk_bug_assert(input_mutation == num_input_mutations); - ret = 0; out: return ret; } static int TSK_WARN_UNUSED -simplifier_finalise_references(simplifier_t *self) +simplifier_finalise_population_references(simplifier_t *self) { int ret = 0; - tsk_id_t ret_id; tsk_size_t j; - bool keep; - tsk_size_t num_nodes = self->tables->nodes.num_rows; - + tsk_id_t pop_id, ret_id; tsk_population_t pop; - tsk_id_t pop_id; - tsk_size_t num_populations = self->input_tables.populations.num_rows; tsk_id_t *node_population = self->tables->nodes.population; + const tsk_size_t num_nodes = self->tables->nodes.num_rows; + const tsk_size_t num_populations = self->input_tables.populations.num_rows; bool *population_referenced = tsk_calloc(num_populations, sizeof(*population_referenced)); tsk_id_t *population_id_map = tsk_malloc(num_populations * sizeof(*population_id_map)); - bool filter_populations = !!(self->options & TSK_SIMPLIFY_FILTER_POPULATIONS); - tsk_individual_t ind; - tsk_id_t ind_id; - tsk_size_t num_individuals = 
self->input_tables.individuals.num_rows; - tsk_id_t *node_individual = self->tables->nodes.individual; - bool *individual_referenced - = tsk_calloc(num_individuals, sizeof(*individual_referenced)); - tsk_id_t *individual_id_map - = tsk_malloc(num_individuals * sizeof(*individual_id_map)); - bool filter_individuals = !!(self->options & TSK_SIMPLIFY_FILTER_INDIVIDUALS); + tsk_bug_assert(self->options & TSK_SIMPLIFY_FILTER_POPULATIONS); - if (population_referenced == NULL || population_id_map == NULL - || individual_referenced == NULL || individual_id_map == NULL) { - goto out; - } - - /* TODO Migrations fit reasonably neatly into the pattern that we have here. We - * can consider references to populations from migration objects in the same way - * as from nodes, so that we only remove a population if its referenced by - * neither. Mapping the population IDs in migrations is then easy. In principle - * nodes are similar, but the semantics are slightly different because we've - * already allocated all the nodes by their references from edges. 
We then - * need to decide whether we remove migrations that reference unmapped nodes - * or whether to add these nodes back in (probably the former is the correct - * approach).*/ - if (self->input_tables.migrations.num_rows != 0) { - ret = TSK_ERR_SIMPLIFY_MIGRATIONS_NOT_SUPPORTED; + if (population_referenced == NULL || population_id_map == NULL) { + ret = TSK_ERR_NO_MEMORY; goto out; } @@ -9614,20 +9568,13 @@ simplifier_finalise_references(simplifier_t *self) if (pop_id != TSK_NULL) { population_referenced[pop_id] = true; } - ind_id = node_individual[j]; - if (ind_id != TSK_NULL) { - individual_referenced[ind_id] = true; - } } + for (j = 0; j < num_populations; j++) { tsk_population_table_get_row_unsafe( &self->input_tables.populations, (tsk_id_t) j, &pop); - keep = true; - if (filter_populations && !population_referenced[j]) { - keep = false; - } population_id_map[j] = TSK_NULL; - if (keep) { + if (population_referenced[j]) { ret_id = tsk_population_table_add_row( &self->tables->populations, pop.metadata, pop.metadata_length); if (ret_id < 0) { @@ -9638,15 +9585,56 @@ simplifier_finalise_references(simplifier_t *self) } } + /* Remap the IDs in the node table */ + for (j = 0; j < num_nodes; j++) { + pop_id = node_population[j]; + if (pop_id != TSK_NULL) { + node_population[j] = population_id_map[pop_id]; + } + } +out: + tsk_safe_free(population_id_map); + tsk_safe_free(population_referenced); + return ret; +} + +static int TSK_WARN_UNUSED +simplifier_finalise_individual_references(simplifier_t *self) +{ + int ret = 0; + tsk_size_t j; + tsk_id_t pop_id, ret_id; + tsk_individual_t ind; + tsk_id_t *node_individual = self->tables->nodes.individual; + tsk_id_t *parents; + const tsk_size_t num_nodes = self->tables->nodes.num_rows; + const tsk_size_t num_individuals = self->input_tables.individuals.num_rows; + bool *individual_referenced + = tsk_calloc(num_individuals, sizeof(*individual_referenced)); + tsk_id_t *individual_id_map + = tsk_malloc(num_individuals * 
sizeof(*individual_id_map)); + + tsk_bug_assert(self->options & TSK_SIMPLIFY_FILTER_INDIVIDUALS); + + if (individual_referenced == NULL || individual_id_map == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + + for (j = 0; j < num_nodes; j++) { + pop_id = node_individual[j]; + if (pop_id != TSK_NULL) { + individual_referenced[pop_id] = true; + } + } + for (j = 0; j < num_individuals; j++) { tsk_individual_table_get_row_unsafe( &self->input_tables.individuals, (tsk_id_t) j, &ind); - keep = true; - if (filter_individuals && !individual_referenced[j]) { - keep = false; - } individual_id_map[j] = TSK_NULL; - if (keep) { + if (individual_referenced[j]) { + /* Can't remap the parents inline here because we have no + * guarantees about sortedness */ ret_id = tsk_individual_table_add_row(&self->tables->individuals, ind.flags, ind.location, ind.location_length, ind.parents, ind.parents_length, ind.metadata, ind.metadata_length); @@ -9658,32 +9646,128 @@ simplifier_finalise_references(simplifier_t *self) } } - /* Remap parent IDs */ - for (j = 0; j < self->tables->individuals.parents_length; j++) { - self->tables->individuals.parents[j] - = self->tables->individuals.parents[j] == TSK_NULL - ? TSK_NULL - : individual_id_map[self->tables->individuals.parents[j]]; - } - - /* Remap node IDs referencing the above */ + /* Remap the IDs in the node table */ for (j = 0; j < num_nodes; j++) { - pop_id = node_population[j]; + pop_id = node_individual[j]; if (pop_id != TSK_NULL) { - node_population[j] = population_id_map[pop_id]; + node_individual[j] = individual_id_map[pop_id]; } - ind_id = node_individual[j]; - if (ind_id != TSK_NULL) { - node_individual[j] = individual_id_map[ind_id]; + } + + /* Remap parent IDs. * + * NOTE! 
must take the pointer reference here as it can change from + * the start of the function */ + parents = self->tables->individuals.parents; + for (j = 0; j < self->tables->individuals.parents_length; j++) { + if (parents[j] != TSK_NULL) { + parents[j] = individual_id_map[parents[j]]; } } - ret = 0; out: - tsk_safe_free(population_referenced); - tsk_safe_free(individual_referenced); - tsk_safe_free(population_id_map); tsk_safe_free(individual_id_map); + tsk_safe_free(individual_referenced); + return ret; +} + +static int TSK_WARN_UNUSED +simplifier_output_sites(simplifier_t *self) +{ + int ret = 0; + tsk_id_t ret_id; + tsk_size_t j; + tsk_mutation_t mutation; + const tsk_size_t num_sites = self->input_tables.sites.num_rows; + const tsk_size_t num_mutations = self->input_tables.mutations.num_rows; + bool *site_referenced = tsk_calloc(num_sites, sizeof(*site_referenced)); + tsk_id_t *site_id_map = tsk_malloc(num_sites * sizeof(*site_id_map)); + tsk_id_t *mutation_id_map = tsk_malloc(num_mutations * sizeof(*mutation_id_map)); + const tsk_id_t *mutation_node_map = self->mutation_node_map; + const tsk_id_t *mutation_site = self->input_tables.mutations.site; + + if (site_referenced == NULL || site_id_map == NULL || mutation_id_map == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + + for (j = 0; j < num_mutations; j++) { + if (mutation_node_map[j] != TSK_NULL) { + site_referenced[mutation_site[j]] = true; + } + } + ret = simplifier_finalise_site_references(self, site_referenced, site_id_map); + if (ret != 0) { + goto out; + } + + for (j = 0; j < num_mutations; j++) { + mutation_id_map[j] = TSK_NULL; + if (mutation_node_map[j] != TSK_NULL) { + tsk_mutation_table_get_row_unsafe( + &self->input_tables.mutations, (tsk_id_t) j, &mutation); + mutation.node = mutation_node_map[j]; + mutation.site = site_id_map[mutation.site]; + if (mutation.parent != TSK_NULL) { + mutation.parent = mutation_id_map[mutation.parent]; + } + ret_id = 
tsk_mutation_table_add_row(&self->tables->mutations, mutation.site, + mutation.node, mutation.parent, mutation.time, mutation.derived_state, + mutation.derived_state_length, mutation.metadata, + mutation.metadata_length); + if (ret_id < 0) { + ret = (int) ret_id; + goto out; + } + mutation_id_map[j] = ret_id; + } + } +out: + tsk_safe_free(site_referenced); + tsk_safe_free(site_id_map); + tsk_safe_free(mutation_id_map); + return ret; +} + +/* Flush the remaining non-edge and node data in the model to the + * output tables. */ +static int TSK_WARN_UNUSED +simplifier_flush_output(simplifier_t *self) +{ + int ret = 0; + + /* TODO Migrations fit reasonably neatly into the pattern that we have here. We + * can consider references to populations from migration objects in the same way + * as from nodes, so that we only remove a population if its referenced by + * neither. Mapping the population IDs in migrations is then easy. In principle + * nodes are similar, but the semantics are slightly different because we've + * already allocated all the nodes by their references from edges. 
We then + * need to decide whether we remove migrations that reference unmapped nodes + * or whether to add these nodes back in (probably the former is the correct + * approach).*/ + if (self->input_tables.migrations.num_rows != 0) { + ret = TSK_ERR_SIMPLIFY_MIGRATIONS_NOT_SUPPORTED; + goto out; + } + + ret = simplifier_output_sites(self); + if (ret != 0) { + goto out; + } + + if (self->options & TSK_SIMPLIFY_FILTER_POPULATIONS) { + ret = simplifier_finalise_population_references(self); + if (ret != 0) { + goto out; + } + } + if (self->options & TSK_SIMPLIFY_FILTER_INDIVIDUALS) { + ret = simplifier_finalise_individual_references(self); + if (ret != 0) { + goto out; + } + } + +out: return ret; } @@ -9732,7 +9816,7 @@ simplifier_insert_input_roots(simplifier_t *self) if (x != NULL) { output_id = self->node_id_map[input_id]; if (output_id == TSK_NULL) { - output_id = simplifier_record_node(self, input_id, false); + output_id = simplifier_record_node(self, input_id); if (output_id < 0) { ret = (int) output_id; goto out; @@ -9797,11 +9881,7 @@ simplifier_run(simplifier_t *self, tsk_id_t *node_map) goto out; } } - ret = simplifier_output_sites(self); - if (ret != 0) { - goto out; - } - ret = simplifier_finalise_references(self); + ret = simplifier_flush_output(self); if (ret != 0) { goto out; } diff --git a/c/tskit/tables.h b/c/tskit/tables.h index 8fc0bc4cfb..bd69b9cc95 100644 --- a/c/tskit/tables.h +++ b/c/tskit/tables.h @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * Copyright (c) 2017-2018 University of Oxford * * Permission is hereby granted, free of charge, to any person obtaining a copy @@ -686,6 +686,13 @@ reference them. 
*/ #define TSK_SIMPLIFY_FILTER_POPULATIONS (1 << 1) /** Remove individuals from the output if there are no nodes that reference them.*/ #define TSK_SIMPLIFY_FILTER_INDIVIDUALS (1 << 2) +/** Do not remove nodes from the output if there are no edges that reference +them and do not reorder nodes so that the samples are nodes 0 to num_samples - 1. +Note that this flag is negated compared to other filtering options because +the default behaviour is to filter unreferenced nodes and reorder to put samples +first. +*/ +#define TSK_SIMPLIFY_NO_FILTER_NODES (1 << 7) /** Reduce the topological information in the tables to the minimum necessary to represent the trees that contain sites. If there are zero sites this will @@ -715,10 +722,6 @@ flag). It keeps unary nodes, but only if the unary node is referenced from an in @endrst */ #define TSK_SIMPLIFY_KEEP_UNARY_IN_INDIVIDUALS (1 << 6) -/** Retain nodes in the output even if no edges reference them. This is negated -compared to the other TSK_SIMPLIFY_FILTER_XXX flags to preserve previous behaviour. -*/ -#define TSK_SIMPLIFY_NO_FILTER_NODES (1 << 7) /** @} */ /** @@ -3913,8 +3916,32 @@ A mapping from the node IDs in the table before simplification to their equivale values after simplification can be obtained via the ``node_map`` argument. If this is non NULL, ``node_map[u]`` will contain the new ID for node ``u`` after simplification, or :c:macro:`TSK_NULL` if the node has been removed. Thus, ``node_map`` must be an array -of at least ``self->nodes.num_rows`` :c:type:`tsk_id_t` values. The table collection will -always be unindexed after simplify successfully completes. +of at least ``self->nodes.num_rows`` :c:type:`tsk_id_t` values. + +If the `TSK_SIMPLIFY_NO_FILTER_NODES` option is specified, the node table will be +unaltered except for changing the sample status of nodes that were samples in the +input tables, but not in the specified list of sample IDs (if provided). 
The +``node_map`` (if specified) will always be the identity mapping, such that +``node_map[u] == u`` for all nodes. Note also that the order of the list of +samples is not important in this case. + +When a table is not filtered (i.e., if the `TSK_SIMPLIFY_NO_FILTER_NODES` +option is provided or the `TSK_SIMPLIFY_FILTER_SITES`, +`TSK_SIMPLIFY_FILTER_POPULATIONS` or `TSK_SIMPLIFY_FILTER_INDIVIDUALS` +options are *not* provided) the corresponding table is modified as +little as possible, and all pointers are guaranteed to remain valid +after simplification. The only changes made to an unfiltered table are +to update any references to tables that may have changed (for example, +remapping population IDs in the node table if +`TSK_SIMPLIFY_FILTER_POPULATIONS` was specified) or altering the +sample status flag of nodes. + +.. note:: It is possible for populations and individuals to be filtered + even if `TSK_SIMPLIFY_NO_FILTER_NODES` is specified because there + may be entirely unreferenced entities in the input tables, which + are not affected by whether we filter nodes or not. + +The table collection will always be unindexed after simplify successfully completes. .. note:: Migrations are currently not supported by simplify, and an error will be raised if we attempt call simplify on a table collection with greater @@ -3928,11 +3955,11 @@ Options can be specified by providing one or more of the following bitwise - :c:macro:`TSK_SIMPLIFY_FILTER_SITES` - :c:macro:`TSK_SIMPLIFY_FILTER_POPULATIONS` - :c:macro:`TSK_SIMPLIFY_FILTER_INDIVIDUALS` +- :c:macro:`TSK_SIMPLIFY_NO_FILTER_NODES` - :c:macro:`TSK_SIMPLIFY_REDUCE_TO_SITE_TOPOLOGY` - :c:macro:`TSK_SIMPLIFY_KEEP_UNARY` - :c:macro:`TSK_SIMPLIFY_KEEP_INPUT_ROOTS` - :c:macro:`TSK_SIMPLIFY_KEEP_UNARY_IN_INDIVIDUALS` -- :c:macro:`TSK_SIMPLIFY_NO_FILTER_NODES` @endrst @param self A pointer to a tsk_table_collection_t object. 
diff --git a/python/tests/simplify.py b/python/tests/simplify.py index e27604bf2c..ef6ebdc25f 100644 --- a/python/tests/simplify.py +++ b/python/tests/simplify.py @@ -130,33 +130,48 @@ def __init__( self.A_tail = [None for _ in range(ts.num_nodes)] self.tables = self.ts.tables.copy() self.tables.clear() - if not filter_nodes: - # NOTE: this is hack-ish. - # So far, we have copied the tables once, - # cleared them, and then re-copied. - self.tables = self.ts.tables.copy() self.edge_buffer = {} self.node_id_map = np.zeros(ts.num_nodes, dtype=np.int32) - 1 + self.is_sample = np.zeros(ts.num_nodes, dtype=np.int8) self.mutation_node_map = [-1 for _ in range(self.num_mutations)] self.samples = set(sample) self.sort_offset = -1 # We keep a map of input nodes to mutations. self.mutation_map = [[] for _ in range(ts.num_nodes)] - position = ts.tables.sites.position - site = ts.tables.mutations.site - node = ts.tables.mutations.node + position = ts.sites_position + site = ts.mutations_site + node = ts.mutations_node for mutation_id in range(ts.num_mutations): site_position = position[site[mutation_id]] self.mutation_map[node[mutation_id]].append((site_position, mutation_id)) + for sample_id in sample: - output_id = self.record_node(sample_id, is_sample=True) - self.add_ancestry(sample_id, 0, self.sequence_length, output_id) + self.is_sample[sample_id] = 1 + + if not self.filter_nodes: + # NOTE In the C implementation we would really just not touch the + # original tables. 
+ self.tables.nodes.replace_with(self.ts.tables.nodes) + # TODO make this optional somehow + flags = self.tables.nodes.flags + # Zero out other sample flags + flags = np.bitwise_and(flags, ~tskit.NODE_IS_SAMPLE) + flags[sample] |= tskit.NODE_IS_SAMPLE + self.tables.nodes.flags = flags.astype(np.uint32) + self.node_id_map[:] = np.arange(ts.num_nodes) + + for sample_id in sample: + self.add_ancestry(sample_id, 0, self.sequence_length, sample_id) + else: + for sample_id in sample: + output_id = self.record_node(sample_id) + self.add_ancestry(sample_id, 0, self.sequence_length, output_id) self.position_lookup = None if self.reduce_to_site_topology: self.position_lookup = np.hstack([[0], position, [self.sequence_length]]) - def record_node(self, input_id, is_sample=False): + def record_node(self, input_id): """ Adds a new node to the output table corresponding to the specified input node ID. @@ -165,11 +180,8 @@ def record_node(self, input_id, is_sample=False): flags = node.flags # Need to zero out the sample flag flags &= ~tskit.NODE_IS_SAMPLE - if is_sample: + if self.is_sample[input_id]: flags |= tskit.NODE_IS_SAMPLE - if not self.filter_nodes: - self.node_id_map[input_id] = input_id - return input_id output_id = self.tables.nodes.append(node.replace(flags=flags)) self.node_id_map[input_id] = output_id return output_id @@ -286,7 +298,7 @@ def merge_labeled_ancestors(self, S, input_id): The new parent must be assigned and any overlapping segments coalesced. """ output_id = self.node_id_map[input_id] - is_sample = output_id != -1 + is_sample = self.is_sample[input_id] if is_sample: # Free up the existing ancestry mapping. 
x = self.A_tail[input_id] @@ -330,7 +342,7 @@ def merge_labeled_ancestors(self, S, input_id): self.add_ancestry(input_id, prev_right, self.sequence_length, output_id) if output_id != -1: num_edges = self.flush_edges() - if num_edges == 0 and not is_sample: + if self.filter_nodes and num_edges == 0 and not is_sample: self.rewind_node(input_id, output_id) def extract_ancestry(self, edge): diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index 841ccc1e09..65d17cd9f2 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -1828,181 +1828,6 @@ def test_max_root_time_corner_cases(self): tables.edges.add_row(0, 1, 3, 1) assert tables.tree_sequence().max_root_time == 3 - def verify_simplify_provenance(self, ts): - new_ts = ts.simplify() - assert new_ts.num_provenances == ts.num_provenances + 1 - old = list(ts.provenances()) - new = list(new_ts.provenances()) - assert old == new[:-1] - # TODO call verify_provenance on this. - assert len(new[-1].timestamp) > 0 - assert len(new[-1].record) > 0 - - new_ts = ts.simplify(record_provenance=False) - assert new_ts.tables.provenances == ts.tables.provenances - - def verify_simplify_topology(self, ts, sample): - new_ts, node_map = ts.simplify(sample, map_nodes=True) - if len(sample) == 0: - assert new_ts.num_nodes == 0 - assert new_ts.num_edges == 0 - assert new_ts.num_sites == 0 - assert new_ts.num_mutations == 0 - elif len(sample) == 1: - assert new_ts.num_nodes == 1 - assert new_ts.num_edges == 0 - # The output samples should be 0...n - assert new_ts.num_samples == len(sample) - assert list(range(len(sample))) == list(new_ts.samples()) - for j in range(new_ts.num_samples): - assert node_map[sample[j]] == j - for u in range(ts.num_nodes): - old_node = ts.node(u) - if node_map[u] != tskit.NULL: - new_node = new_ts.node(node_map[u]) - assert old_node.time == new_node.time - assert old_node.population == new_node.population - assert old_node.metadata == new_node.metadata - for u 
in sample: - old_node = ts.node(u) - new_node = new_ts.node(node_map[u]) - assert old_node.flags == new_node.flags - assert old_node.time == new_node.time - assert old_node.population == new_node.population - assert old_node.metadata == new_node.metadata - old_trees = ts.trees() - old_tree = next(old_trees) - assert ts.get_num_trees() >= new_ts.get_num_trees() - for new_tree in new_ts.trees(): - new_left, new_right = new_tree.get_interval() - old_left, old_right = old_tree.get_interval() - # Skip ahead on the old tree until new_left is within its interval - while old_right <= new_left: - old_tree = next(old_trees) - old_left, old_right = old_tree.get_interval() - # If the MRCA of all pairs of samples is the same, then we have the - # same information. We limit this to at most 500 pairs - pairs = itertools.islice(itertools.combinations(sample, 2), 500) - for pair in pairs: - mapped_pair = [node_map[u] for u in pair] - mrca1 = old_tree.get_mrca(*pair) - mrca2 = new_tree.get_mrca(*mapped_pair) - if mrca1 == tskit.NULL: - assert mrca2 == mrca1 - else: - assert mrca2 == node_map[mrca1] - assert old_tree.get_time(mrca1) == new_tree.get_time(mrca2) - assert old_tree.get_population(mrca1) == new_tree.get_population( - mrca2 - ) - - def verify_simplify_equality(self, ts, sample): - for filter_sites in [False, True]: - s1, node_map1 = ts.simplify( - sample, map_nodes=True, filter_sites=filter_sites - ) - t1 = s1.dump_tables() - s2, node_map2 = simplify_tree_sequence( - ts, sample, filter_sites=filter_sites - ) - t2 = s2.dump_tables() - assert s1.num_samples == len(sample) - assert s2.num_samples == len(sample) - assert all(node_map1 == node_map2) - assert t1.individuals == t2.individuals - assert t1.nodes == t2.nodes - assert t1.edges == t2.edges - assert t1.migrations == t2.migrations - assert t1.sites == t2.sites - assert t1.mutations == t2.mutations - assert t1.populations == t2.populations - - def verify_simplify_variants(self, ts, sample): - subset = ts.simplify(sample) 
- sample_map = {u: j for j, u in enumerate(ts.samples())} - # Need to map IDs back to their sample indexes - s = np.array([sample_map[u] for u in sample]) - # Build a map of genotypes by position - full_genotypes = {} - for variant in ts.variants(isolated_as_missing=False): - alleles = [variant.alleles[g] for g in variant.genotypes] - full_genotypes[variant.position] = alleles - for variant in subset.variants(isolated_as_missing=False): - if variant.position in full_genotypes: - a1 = [full_genotypes[variant.position][u] for u in s] - a2 = [variant.alleles[g] for g in variant.genotypes] - assert a1 == a2 - - def verify_tables_api_equality(self, ts): - for samples in [None, list(ts.samples()), ts.samples()]: - tables = ts.dump_tables() - tables.simplify(samples=samples) - tables.assert_equals( - ts.simplify(samples=samples).tables, ignore_timestamps=True - ) - - @pytest.mark.slow - def test_simplify(self): - num_mutations = 0 - for ts in get_example_tree_sequences(pytest_params=False): - # Can't simplify edges with metadata - if ts.tables.edges.metadata_schema == tskit.MetadataSchema(schema=None): - self.verify_tables_api_equality(ts) - self.verify_simplify_provenance(ts) - n = ts.num_samples - num_mutations += ts.num_mutations - sample_sizes = {0} - if n > 1: - sample_sizes |= {1} - if n > 2: - sample_sizes |= {2, max(2, n // 2), n - 1} - for k in sample_sizes: - subset = random.sample(list(ts.samples()), k) - self.verify_simplify_topology(ts, subset) - self.verify_simplify_equality(ts, subset) - self.verify_simplify_variants(ts, subset) - assert num_mutations > 0 - - def test_simplify_bugs(self): - prefix = os.path.join(os.path.dirname(__file__), "data", "simplify-bugs") - j = 1 - while True: - nodes_file = os.path.join(prefix, f"{j:02d}-nodes.txt") - if not os.path.exists(nodes_file): - break - edges_file = os.path.join(prefix, f"{j:02d}-edges.txt") - sites_file = os.path.join(prefix, f"{j:02d}-sites.txt") - mutations_file = os.path.join(prefix, 
f"{j:02d}-mutations.txt") - with open(nodes_file) as nodes, open(edges_file) as edges, open( - sites_file - ) as sites, open(mutations_file) as mutations: - ts = tskit.load_text( - nodes=nodes, - edges=edges, - sites=sites, - mutations=mutations, - strict=False, - ) - samples = list(ts.samples()) - self.verify_simplify_equality(ts, samples) - j += 1 - assert j > 1 - - def test_simplify_migrations_fails(self): - ts = msprime.simulate( - population_configurations=[ - msprime.PopulationConfiguration(10), - msprime.PopulationConfiguration(10), - ], - migration_matrix=[[0, 1], [1, 0]], - random_seed=2, - record_migrations=True, - ) - assert ts.num_migrations > 0 - # We don't support simplify with migrations, so should fail. - with pytest.raises(_tskit.LibraryError): - ts.simplify() - def test_subset_reverse_all_nodes(self): ts = tskit.Tree.generate_comb(5).tree_sequence assert np.all(ts.samples() == np.arange(ts.num_samples)) @@ -2770,6 +2595,197 @@ def test_arrays_equal_to_tables(self, ts_fixture): ) +class TestSimplify: + # This class was factored out of the old TestHighlevel class 2022-12-13, + # and is a mishmash of different testing paradigms. There is some valuable + # testing done here, so it would be good to fully bring it up to date. + + def verify_simplify_provenance(self, ts): + new_ts = ts.simplify() + assert new_ts.num_provenances == ts.num_provenances + 1 + old = list(ts.provenances()) + new = list(new_ts.provenances()) + assert old == new[:-1] + # TODO call verify_provenance on this. 
+ assert len(new[-1].timestamp) > 0 + assert len(new[-1].record) > 0 + + new_ts = ts.simplify(record_provenance=False) + assert new_ts.tables.provenances == ts.tables.provenances + + def verify_simplify_topology(self, ts, sample): + new_ts, node_map = ts.simplify(sample, map_nodes=True) + if len(sample) == 0: + assert new_ts.num_nodes == 0 + assert new_ts.num_edges == 0 + assert new_ts.num_sites == 0 + assert new_ts.num_mutations == 0 + elif len(sample) == 1: + assert new_ts.num_nodes == 1 + assert new_ts.num_edges == 0 + # The output samples should be 0...n + assert new_ts.num_samples == len(sample) + assert list(range(len(sample))) == list(new_ts.samples()) + for j in range(new_ts.num_samples): + assert node_map[sample[j]] == j + for u in range(ts.num_nodes): + old_node = ts.node(u) + if node_map[u] != tskit.NULL: + new_node = new_ts.node(node_map[u]) + assert old_node.time == new_node.time + assert old_node.population == new_node.population + assert old_node.metadata == new_node.metadata + for u in sample: + old_node = ts.node(u) + new_node = new_ts.node(node_map[u]) + assert old_node.flags == new_node.flags + assert old_node.time == new_node.time + assert old_node.population == new_node.population + assert old_node.metadata == new_node.metadata + old_trees = ts.trees() + old_tree = next(old_trees) + assert ts.get_num_trees() >= new_ts.get_num_trees() + for new_tree in new_ts.trees(): + new_left, new_right = new_tree.get_interval() + old_left, old_right = old_tree.get_interval() + # Skip ahead on the old tree until new_left is within its interval + while old_right <= new_left: + old_tree = next(old_trees) + old_left, old_right = old_tree.get_interval() + # If the MRCA of all pairs of samples is the same, then we have the + # same information. 
We limit this to at most 500 pairs + pairs = itertools.islice(itertools.combinations(sample, 2), 500) + for pair in pairs: + mapped_pair = [node_map[u] for u in pair] + mrca1 = old_tree.get_mrca(*pair) + mrca2 = new_tree.get_mrca(*mapped_pair) + if mrca1 == tskit.NULL: + assert mrca2 == mrca1 + else: + assert mrca2 == node_map[mrca1] + assert old_tree.get_time(mrca1) == new_tree.get_time(mrca2) + assert old_tree.get_population(mrca1) == new_tree.get_population( + mrca2 + ) + + def verify_simplify_equality(self, ts, sample): + for filter_sites in [False, True]: + s1, node_map1 = ts.simplify( + sample, map_nodes=True, filter_sites=filter_sites + ) + t1 = s1.dump_tables() + s2, node_map2 = simplify_tree_sequence( + ts, sample, filter_sites=filter_sites + ) + t2 = s2.dump_tables() + assert s1.num_samples == len(sample) + assert s2.num_samples == len(sample) + assert all(node_map1 == node_map2) + assert t1.individuals == t2.individuals + assert t1.nodes == t2.nodes + assert t1.edges == t2.edges + assert t1.migrations == t2.migrations + assert t1.sites == t2.sites + assert t1.mutations == t2.mutations + assert t1.populations == t2.populations + + def verify_simplify_variants(self, ts, sample): + subset = ts.simplify(sample) + sample_map = {u: j for j, u in enumerate(ts.samples())} + # Need to map IDs back to their sample indexes + s = np.array([sample_map[u] for u in sample]) + # Build a map of genotypes by position + full_genotypes = {} + for variant in ts.variants(isolated_as_missing=False): + alleles = [variant.alleles[g] for g in variant.genotypes] + full_genotypes[variant.position] = alleles + for variant in subset.variants(isolated_as_missing=False): + if variant.position in full_genotypes: + a1 = [full_genotypes[variant.position][u] for u in s] + a2 = [variant.alleles[g] for g in variant.genotypes] + assert a1 == a2 + + def verify_tables_api_equality(self, ts): + for samples in [None, list(ts.samples()), ts.samples()]: + tables = ts.dump_tables() + 
tables.simplify(samples=samples) + tables.assert_equals( + ts.simplify(samples=samples).tables, ignore_timestamps=True + ) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_simplify_tables_equality(self, ts): + # Can't simplify edges with metadata + if ts.tables.edges.metadata_schema == tskit.MetadataSchema(schema=None): + self.verify_tables_api_equality(ts) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_simplify_provenance(self, ts): + # Can't simplify edges with metadata + if ts.tables.edges.metadata_schema == tskit.MetadataSchema(schema=None): + self.verify_simplify_provenance(ts) + + # TODO this test needs to be broken up into discrete bits, so that we can + # test them independently. A way of getting a random-ish subset of samples + # from the pytest param would be useful. + @pytest.mark.slow + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_simplify(self, ts): + # Can't simplify edges with metadata + if ts.tables.edges.metadata_schema == tskit.MetadataSchema(schema=None): + n = ts.num_samples + sample_sizes = {0} + if n > 1: + sample_sizes |= {1} + if n > 2: + sample_sizes |= {2, max(2, n // 2), n - 1} + for k in sample_sizes: + subset = random.sample(list(ts.samples()), k) + self.verify_simplify_topology(ts, subset) + self.verify_simplify_equality(ts, subset) + self.verify_simplify_variants(ts, subset) + + def test_simplify_bugs(self): + prefix = os.path.join(os.path.dirname(__file__), "data", "simplify-bugs") + j = 1 + while True: + nodes_file = os.path.join(prefix, f"{j:02d}-nodes.txt") + if not os.path.exists(nodes_file): + break + edges_file = os.path.join(prefix, f"{j:02d}-edges.txt") + sites_file = os.path.join(prefix, f"{j:02d}-sites.txt") + mutations_file = os.path.join(prefix, f"{j:02d}-mutations.txt") + with open(nodes_file) as nodes, open(edges_file) as edges, open( + sites_file + ) as sites, open(mutations_file) as mutations: + ts = tskit.load_text( + nodes=nodes, + 
edges=edges, + sites=sites, + mutations=mutations, + strict=False, + ) + samples = list(ts.samples()) + self.verify_simplify_equality(ts, samples) + j += 1 + assert j > 1 + + def test_simplify_migrations_fails(self): + ts = msprime.simulate( + population_configurations=[ + msprime.PopulationConfiguration(10), + msprime.PopulationConfiguration(10), + ], + migration_matrix=[[0, 1], [1, 0]], + random_seed=2, + record_migrations=True, + ) + assert ts.num_migrations > 0 + # We don't support simplify with migrations, so should fail. + with pytest.raises(_tskit.LibraryError): + ts.simplify() + + class TestMinMaxTime: def get_example_tree_sequence(self, use_unknown_time): """ diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 8133c9d7eb..31abe8b6e3 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -347,9 +347,29 @@ def test_simplify_bad_args(self): tc.simplify([0, 1], keep_input_roots="sdf") with pytest.raises(TypeError): tc.simplify([0, 1], filter_populations="x") + with pytest.raises(TypeError): + tc.simplify([0, 1], filter_nodes="x") with pytest.raises(_tskit.LibraryError): tc.simplify([0, -1]) + @pytest.mark.parametrize("value", [True, False]) + @pytest.mark.parametrize( + "flag", + [ + "filter_sites", + "filter_populations", + "filter_individuals", + "filter_nodes", + "reduce_to_site_topology", + "keep_unary", + "keep_unary_in_individuals", + "keep_input_roots", + ], + ) + def test_simplify_flags(self, flag, value): + tables = _tskit.TableCollection(1) + tables.simplify([], **{flag: value}) + def test_link_ancestors_bad_args(self): ts = msprime.simulate(10, random_seed=1) tc = ts.tables._ll_tables diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py index dae69c6f85..919f2309ef 100644 --- a/python/tests/test_tables.py +++ b/python/tests/test_tables.py @@ -3029,8 +3029,8 @@ def test_full_samples(self): def test_bad_samples(self): n = 10 ts = msprime.simulate(n, random_seed=self.random_seed) 
- tables = ts.dump_tables() - for bad_node in [-1, n, n + 1, ts.num_nodes - 1, ts.num_nodes, 2**31 - 1]: + for bad_node in [-1, ts.num_nodes, 2**31 - 1]: + tables = ts.dump_tables() with pytest.raises(_tskit.LibraryError): tables.simplify(samples=[0, bad_node]) diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index ac1368e921..356c3fb316 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -2686,7 +2686,7 @@ def verify_simplify( filter_sites=filter_sites, keep_input_roots=keep_input_roots, filter_nodes=filter_nodes, - compare_lib=False, # TMP + compare_lib=True, # TMP ) if debug: print("before") @@ -4799,18 +4799,10 @@ def do_simplify( ) py_tables = new_ts.dump_tables() - for lib_tables, lib_node_map in [ - (lib_tables1, lib_node_map1), - (lib_tables2, lib_node_map2), - ]: - assert lib_tables.nodes == py_tables.nodes - assert lib_tables.edges == py_tables.edges - assert lib_tables.migrations == py_tables.migrations - assert lib_tables.sites == py_tables.sites - assert lib_tables.mutations == py_tables.mutations - assert lib_tables.individuals == py_tables.individuals - assert lib_tables.populations == py_tables.populations - assert all(node_map == lib_node_map) + py_tables.assert_equals(lib_tables1, ignore_provenance=True) + py_tables.assert_equals(lib_tables2, ignore_provenance=True) + assert all(node_map == lib_node_map1) + assert all(node_map == lib_node_map2) return new_ts, node_map @@ -5663,6 +5655,96 @@ def test_many_trees_recurrent_mutations_internal_samples(self): self.verify_simplify_haplotypes(ts, samples, keep_unary=keep) +class TestSimplifyUnreferencedPopulations: + def example(self): + tables = tskit.TableCollection(1) + tables.populations.add_row() + tables.populations.add_row() + # No references to population 0 + tables.nodes.add_row(time=0, population=1, flags=1) + tables.nodes.add_row(time=0, population=1, flags=1) + tables.nodes.add_row(time=1, population=1, flags=0) + # Unreference node + 
tables.nodes.add_row(time=1, population=1, flags=0) + tables.edges.add_row(0, 1, parent=2, child=0) + tables.edges.add_row(0, 1, parent=2, child=1) + tables.sort() + return tables + + def test_no_filter_populations(self): + tables = self.example() + tables.simplify(filter_populations=False) + assert len(tables.populations) == 2 + assert len(tables.nodes) == 3 + assert np.all(tables.nodes.population == 1) + + def test_no_filter_populations_nodes(self): + tables = self.example() + tables.simplify(filter_populations=False, filter_nodes=False) + assert len(tables.populations) == 2 + assert len(tables.nodes) == 4 + assert np.all(tables.nodes.population == 1) + + def test_filter_populations_no_filter_nodes(self): + tables = self.example() + tables.simplify(filter_populations=True, filter_nodes=False) + assert len(tables.populations) == 1 + assert len(tables.nodes) == 4 + assert np.all(tables.nodes.population == 0) + + def test_remapped_default(self): + tables = self.example() + tables.simplify() + assert len(tables.populations) == 1 + assert len(tables.nodes) == 3 + assert np.all(tables.nodes.population == 0) + + +class TestSimplifyUnreferencedIndividuals: + def example(self): + tables = tskit.TableCollection(1) + tables.individuals.add_row() + tables.individuals.add_row() + # No references to individual 0 + tables.nodes.add_row(time=0, individual=1, flags=1) + tables.nodes.add_row(time=0, individual=1, flags=1) + tables.nodes.add_row(time=1, individual=1, flags=0) + # Unreference node + tables.nodes.add_row(time=1, individual=1, flags=0) + tables.edges.add_row(0, 1, parent=2, child=0) + tables.edges.add_row(0, 1, parent=2, child=1) + tables.sort() + return tables + + def test_no_filter_individuals(self): + tables = self.example() + tables.simplify(filter_individuals=False) + assert len(tables.individuals) == 2 + assert len(tables.nodes) == 3 + assert np.all(tables.nodes.individual == 1) + + def test_no_filter_individuals_nodes(self): + tables = self.example() + 
tables.simplify(filter_individuals=False, filter_nodes=False) + assert len(tables.individuals) == 2 + assert len(tables.nodes) == 4 + assert np.all(tables.nodes.individual == 1) + + def test_filter_individuals_no_filter_nodes(self): + tables = self.example() + tables.simplify(filter_individuals=True, filter_nodes=False) + assert len(tables.individuals) == 1 + assert len(tables.nodes) == 4 + assert np.all(tables.nodes.individual == 0) + + def test_remapped_default(self): + tables = self.example() + tables.simplify() + assert len(tables.individuals) == 1 + assert len(tables.nodes) == 3 + assert np.all(tables.nodes.individual == 0) + + class TestSimplifyKeepInputRoots(SimplifyTestBase, ExampleTopologyMixin): """ Tests for the keep_input_roots option to simplify. @@ -5847,7 +5929,7 @@ def verify_nodes_unchanged(self, ts_in, resample_size=None, **kwargs): for ts in (ts_in, self.reverse_node_indexes(ts_in)): filtered, n_map = do_simplify( - ts, samples=samples, filter_nodes=False, compare_lib=False, **kwargs + ts, samples=samples, filter_nodes=False, compare_lib=True, **kwargs ) assert np.array_equal(n_map, np.arange(ts.num_nodes, dtype=n_map.dtype)) referenced_nodes = set(filtered.samples()) @@ -5855,12 +5937,9 @@ def verify_nodes_unchanged(self, ts_in, resample_size=None, **kwargs): referenced_nodes.update(filtered.edges_child) for n1, n2 in zip(ts.nodes(), filtered.nodes()): # Ignore the tskit.NODE_IS_SAMPLE flag which can be changed by simplify - if n2.id in referenced_nodes: - assert n_map[n2.id] == tskit.NULL - else: - n1 = n1.replace(flags=n1.flags | tskit.NODE_IS_SAMPLE) - n2 = n2.replace(flags=n2.flags | tskit.NODE_IS_SAMPLE) - assert n1 == n2 + n1 = n1.replace(flags=n1.flags | tskit.NODE_IS_SAMPLE) + n2 = n2.replace(flags=n2.flags | tskit.NODE_IS_SAMPLE) + assert n1 == n2 # Check that edges are identical to the normal simplify(), # with the normal "simplify" having altered IDs @@ -5971,6 +6050,46 @@ def test_keeping_unary(self): 
self.verify_nodes_unchanged(ts_with_unary, keep_unary=True) self.verify_nodes_unchanged(ts_with_unary, keep_unary=False) + def test_find_unreferenced_nodes(self): + # Simple test to show we can find unreferenced nodes easily. + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts1 = tskit.Tree.generate_balanced(4).tree_sequence + ts2, node_map = do_simplify( + ts1, + [0, 1, 2], + filter_nodes=False, + ) + assert np.array_equal(node_map, np.arange(ts1.num_nodes)) + node_references = np.zeros(ts1.num_nodes, dtype=np.int32) + node_references[ts2.edges_parent] += 1 + node_references[ts2.edges_child] += 1 + # Simplifying for [0, 1, 2] should remove references to node 3 and 5 + assert list(node_references) == [1, 1, 1, 0, 2, 0, 1] + + def test_mutations_on_removed_branches(self): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + tables = tskit.Tree.generate_balanced(4).tree_sequence.dump_tables() + # A mutation on a removed branch should get removed + tables.sites.add_row(0.5, "A") + tables.mutations.add_row(0, node=3, derived_state="T") + ts2, node_map = do_simplify( + tables.tree_sequence(), + [0, 1, 2], + filter_nodes=False, + ) + assert ts2.num_sites == 0 + assert ts2.num_mutations == 0 + class TestMapToAncestors: """ diff --git a/python/tskit/tables.py b/python/tskit/tables.py index 483186ffc6..b875115621 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -3358,15 +3358,26 @@ def simplify( """ Simplifies the tables in place to retain only the information necessary to reconstruct the tree sequence describing the given ``samples``. - If ``filter_nodes`` is True, this can change the ID of the nodes, so - that the node ``samples[k]`` will have ID ``k`` in the result, resulting - in a NodeTable where only the first ``len(samples)`` nodes are marked - as samples. 
The mapping from node IDs in the current set of tables to - their equivalent values in the simplified tables is also returned as a - numpy array. If an array ``a`` is returned by this function and ``u`` - is the ID of a node in the input table, then ``a[u]`` is the ID of this - node in the output table. For any node ``u`` that is not mapped into - the output tables, this mapping will equal ``-1``. + If ``filter_nodes`` is True (the default), this can change the ID of + the nodes, so that the node ``samples[k]`` will have ID ``k`` in the + result, resulting in a NodeTable where only the first ``len(samples)`` + nodes are marked as samples. The mapping from node IDs in the current + set of tables to their equivalent values in the simplified tables is + also returned as a numpy array. If an array ``a`` is returned by this + function and ``u`` is the ID of a node in the input table, then + ``a[u]`` is the ID of this node in the output table. For any node ``u`` + that is not mapped into the output tables, this mapping will equal + ``-1``. + + If ``filter_nodes`` is False, then the output node table will be + unchanged except for updating the sample status of nodes. Nodes that + are in the specified list of ``samples`` will be marked as samples + in the output, and nodes that are currently marked as samples in + the node table but **not** in the specified list of ``samples`` + will have their sample flag cleared. Note also that the order of + the ``samples`` list is not meaningful when ``filter_nodes`` is False. + The returned node mapping is always the identity mapping, such that + ``a[u] == u`` for all nodes. 
Tables operated on by this function must: be sorted
        (see :meth:`TableCollection.sort`), have children be born strictly after their
diff --git a/python/tskit/trees.py b/python/tskit/trees.py
index ba1d109e03..f854947392 100644
--- a/python/tskit/trees.py
+++ b/python/tskit/trees.py
@@ -6506,14 +6506,14 @@ def simplify(
         original tree sequence, or :data:`tskit.NULL` (-1) if ``u`` is no longer
         present in the simplified tree sequence.
 
-        In the returned tree sequence the only nodes flagged as samples are those
-        passed as ``samples``: all others are not flagged as samples.
-        If ``filter_nodes`` is not False, nodes in the returned tree sequence are
-        also reordered such that the node with ID ``0`` corresponds to ``samples[0]``,
-        node ``1`` corresponds to ``samples[1]`` etc., and the remaining node IDs
-        are allocated sequentially in time order. Alternatively, if ``filter_nodes``
-        is False, the node order is not changed, and the order of IDs passed to
-        ``samples`` is irrelevant.
+        In the returned tree sequence the only nodes flagged as samples are
+        those passed as ``samples``: all others are not flagged as samples. If
+        ``filter_nodes`` is True (the default), nodes in the returned tree
+        sequence are also reordered such that the node with ID ``0``
+        corresponds to ``samples[0]``, node ``1`` corresponds to ``samples[1]``
+        etc., and the remaining node IDs are allocated sequentially in time
+        order. Alternatively, if ``filter_nodes`` is False, the node order is
+        not changed, and the order of IDs passed to ``samples`` is irrelevant.
 
         If you wish to simplify a set of tables that do not satisfy all
         requirements for building a TreeSequence, then use

From 1167966ccdc3facf21e61c13b155389c6ffbc390 Mon Sep 17 00:00:00 2001
From: Yan Wong
Date: Fri, 23 Dec 2022 09:35:54 +0000
Subject: [PATCH 09/84] Document root_threshold

And note issue with large numbers of roots in missing regions.
Fixes #2629 --- docs/data-model.md | 67 +++++++++++++++++++++++++++++-------------- python/tskit/trees.py | 9 +++++- 2 files changed, 53 insertions(+), 23 deletions(-) diff --git a/docs/data-model.md b/docs/data-model.md index 20b8957175..3339f425cb 100644 --- a/docs/data-model.md +++ b/docs/data-model.md @@ -830,12 +830,13 @@ HTML(html_quintuple_table(ts, show_convenience_arrays=True)) ### Roots -The roots of a tree are defined as the unique endpoints of upward paths -starting from sample nodes ({ref}`isolated` -sample nodes also count as roots). Thus, trees can have multiple roots in `tskit`. -For example, if we delete the edge joining `6` and `7` in the previous -example, we get a tree with two roots: - +In the `tskit` {class}`trees ` we have shown so far, all the sample nodes have +been connected to each other. This means each tree has only a single {attr}`~Tree.root` +(i.e. the oldest node found when tracing a path backwards in time from any sample). +However, a tree can contain {ref}`sec_data_model_tree_isolated_sample_nodes` +or unconnected topologies, and can therefore have *multiple* {attr}`~Tree.roots`. +Here's an example, created by deleting the edge joining `6` and `7` in the tree sequence +used above: ```{code-cell} ipython3 :tags: ["hide-input"] @@ -845,7 +846,7 @@ ts_multiroot = tables.tree_sequence() SVG(ts_multiroot.first().draw_svg(time_scale="rank")) ``` -Note that in tree sequence terminology, this should *not* be thought +In `tskit` terminology, this should *not* be thought of as two separate trees, but as a single multi-root "tree", comprising two unlinked topologies. This fits with the definition of a tree in a tree sequence: a tree describes the ancestry of the same @@ -853,19 +854,34 @@ fixed set of sample nodes at a single position in the genome. In the picture above, *both* the left and right hand topologies are required to describe the genealogy of samples 0..4 at this position. 
-Here's what it looks like for an entire tree sequence: +Here's what the entire tree sequence now looks like: ```{code-cell} ipython3 :tags: ["hide-input"] SVG(ts_multiroot.draw_svg(time_scale="rank")) ``` -This tree sequence consists of three trees. The first tree, which applies from -position 0 to 20, is the one used in our example. As we saw, removing the edge -connecting node 6 to node 7 has created a tree with 2 roots (and thus 2 -unconnected topologies in a single tree). In contrast, the second tree, from -position 20 to 40, has a single root. Finally the third tree, from position -40 to 60, again has two roots. +From the terminology above, it can be seen that this tree sequence consists of only +three trees (not five). The first tree, which applies from position 0 to 20, is the one +used in our example. As we saw, removing the edge connecting node 6 to node 7 has +created a tree with 2 roots (and thus 2 unconnected topologies in a single tree). +In contrast, the second tree, from position 20 to 40, has a single root. Finally the +third tree, from position 40 to 60, again has two roots. + +(sec_data_model_tree_root_threshold)= + +#### The root threshold + +The roots of a tree are defined by reference to the +{ref}`sample nodes`. By default, roots are the unique +endpoints of the paths traced upwards from the sample nodes; equivalently, each root +counts one or more samples among its descendants (or is itself a sample node). This is +the case when the {attr}`~Tree.root_threshold` property of a tree is left at its default +value of `1`. If, however, the `root_threshold` is (say) `2`, then a node is +considered a root only if it counts at least two samples among its descendants. Setting +an alternative `root_threshold` value can be used to avoid visiting +{ref}`sec_data_model_tree_isolated_sample_nodes`, for example when dealing with trees +containing {ref}`sec_data_model_missing_data`. 
(sec_data_model_tree_virtual_root)= @@ -940,11 +956,18 @@ for tree in ts_multiroot.trees(): ) ``` -However, it is also possible for a {ref}`sample node` -to be isolated. Unlike other nodes, isolated *sample* nodes are still considered as -being present on the tree (meaning they will still returned by the {meth}`Tree.nodes` -and {meth}`Tree.samples` methods): they are therefore plotted, but unconnected to any -other nodes. To illustrate, we can remove the edge from node 2 to node 7. + +(sec_data_model_tree_isolated_sample_nodes)= + +#### Isolated sample nodes + +It is also possible for a {ref}`sample node` +to be isolated. As long as the {ref}`root threshold` +is set to its default value, an isolated *sample* node will count as a root, and +therefore be considered as being present on the tree (meaning it will be +returned by the {meth}`Tree.nodes` +and {meth}`Tree.samples` methods). When displaying a tree, isolated samples are shown +unconnected to other nodes. To illustrate, we can remove the edge from node 2 to node 7: ```{code-cell} ipython3 :tags: ["hide-input"] @@ -955,9 +978,9 @@ ts_isolated = tables.tree_sequence() SVG(ts_isolated.draw_svg(time_scale="rank")) ``` -The rightmost tree now contains an isolated sample node (node 2). Isolated -sample nodes count as one of the {ref}`sec_data_model_tree_roots` of the tree, -so that tree has three roots, one of which is node 2: +The rightmost tree now contains an isolated sample node (node 2), which counts as +one of the {ref}`sec_data_model_tree_roots` of the tree. 
This tree therefore has three +roots, one of which is node 2: ```{code-cell} ipython3 rightmost_tree = ts_isolated.at_index(-1) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index f854947392..c86293c6d9 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -727,7 +727,8 @@ def tree_sequence(self): def root_threshold(self): """ Returns the minimum number of samples that a node must be an ancestor - of to be considered a potential root. + of to be considered a potential root. This can be set, for example, when + calling the :meth:`TreeSequence.trees` iterator. :return: The root threshold. :rtype: :class:`TreeSequence` @@ -1548,6 +1549,12 @@ def roots(self): Only requires O(number of roots) time. + .. note:: + In trees with large amounts of :ref:`sec_data_model_missing_data`, + for example where a region of the genome lacks any ancestral information, + there can be a very large number of roots, potentially all the samples + in the tree sequence. + :return: The list of roots in this tree. :rtype: list """ From 8aea74e923ea11a573627b7025d783410b0bbfbf Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Fri, 23 Dec 2022 19:48:23 +0000 Subject: [PATCH 10/84] Implement is_root Fixes #2620 --- python/CHANGELOG.rst | 3 +++ python/tests/test_highlevel.py | 15 +++++++++++++++ python/tskit/trees.py | 17 +++++++++++++++++ 3 files changed, 35 insertions(+) diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 0b3431c467..c3dc69c1e5 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -4,6 +4,9 @@ **Features** +- A new ``Tree.is_root`` method avoids the need to to search the potentially + large list of ``Tree.roots`` (:user:`hyanwong`, :pr:`2669`, :issue:`2620`) + - The ``TreeSequence`` object now has the attributes ``min_time`` and ``max_time``, which are the minimum and maximum among the node times and mutation times, respectively. 
(:user:`szhan`, :pr:`2612`, :issue:`2271`) diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index 65d17cd9f2..3c62fa476a 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -3899,6 +3899,21 @@ def test_simple_root_threshold(self): tree = tskit.Tree.generate_balanced(3, root_threshold=4) assert tree.num_roots == 0 + @pytest.mark.parametrize("root_threshold", [1, 2, 3]) + def test_is_root(self, root_threshold): + # Make a tree with multiple roots with different numbers of samples under each + ts = tskit.Tree.generate_balanced(5).tree_sequence + ts = ts.decapitate(ts.max_root_time - 0.1) + tables = ts.dump_tables() + tables.nodes.add_row(flags=0) # Isolated non-sample + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE) # Isolated sample + ts = tables.tree_sequence() + assert {ts.first().num_samples(u) for u in ts.first().roots} == {1, 2, 3} + tree = ts.first(root_threshold=root_threshold) + roots = set(tree.roots) + for u in range(ts.num_nodes): # Will also test isolated nodes + assert tree.is_root(u) == (u in roots) + def test_is_descendant(self): def is_descendant(tree, u, v): path = [] diff --git a/python/tskit/trees.py b/python/tskit/trees.py index c86293c6d9..60fe9699e5 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -1583,6 +1583,23 @@ def root(self): raise ValueError("More than one root exists. Use tree.roots instead") return self.left_root + def is_root(self, u) -> bool: + """ + Returns ``True`` if the specified node is a root in this tree (see + :attr:`~Tree.roots` for the definition of a root). This is exactly equivalent to + finding the node ID in :attr:`~Tree.roots`, but is more efficient for trees + with large numbers of roots, such as in regions with extensive + :ref:`sec_data_model_missing_data`. Note that ``False`` is returned for all + other nodes, including :ref:`isolated` + non-sample nodes which are not found in the topology of the current tree. 
+ + :param int u: The node of interest. + :return: ``True`` if u is a root. + """ + return ( + self.num_samples(u) >= self.root_threshold and self.parent(u) == tskit.NULL + ) + def get_index(self): # Deprecated alias for self.index return self.index From 78ccda7727b9494a2fb4c67e71461f1e06f37087 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Tue, 1 Nov 2022 19:32:55 +0000 Subject: [PATCH 11/84] Support Py3.11 --- .github/workflows/docker/shared.env | 1 + .github/workflows/docs.yml | 13 ++++++----- .github/workflows/tests.yml | 8 +++++-- .github/workflows/wheels.yml | 22 ++++++++++--------- .mergify.yml | 9 ++++++++ python/requirements/CI-docs/requirements.txt | 2 +- .../CI-tests-conda/requirements.txt | 1 + .../CI-tests-pip/requirements.txt | 6 ++--- python/setup.cfg | 1 + 9 files changed, 42 insertions(+), 21 deletions(-) diff --git a/.github/workflows/docker/shared.env b/.github/workflows/docker/shared.env index 6f613fb81c..26a8ea4318 100644 --- a/.github/workflows/docker/shared.env +++ b/.github/workflows/docker/shared.env @@ -1,4 +1,5 @@ PYTHON_VERSIONS=( + cp311-cp311 cp310-cp310 cp39-cp39 cp38-cp38 diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 2da3ccc3b0..201ecd3a4e 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -19,29 +19,32 @@ env: jobs: build-deploy-docs: name: Docs - runs-on: ubuntu-18.04 + runs-on: ubuntu-latest steps: - name: Cancel Previous Runs uses: styfle/cancel-workflow-action@0.6.0 with: access_token: ${{ github.token }} - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - - uses: actions/setup-python@v2 + - uses: actions/setup-python@v4 with: python-version: 3.8 - - uses: actions/cache@v2 + - uses: actions/cache@v3 id: cache with: path: venv - key: docs-venv-v2-${{ hashFiles(env.REQUIREMENTS) }} + key: docs-venv-v4-${{ hashFiles(env.REQUIREMENTS) }} - name: Build virtualenv if: steps.cache.outputs.cache-hit != 'true' run: python -m venv venv + - name: Downgrade pip + run: 
venv/bin/activate && pip install pip==20.0.2 + - name: Install deps run: venv/bin/activate && pip install -r ${{env.REQUIREMENTS}} diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8c617bba7e..bec16ccd64 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -61,7 +61,7 @@ jobs: strategy: fail-fast: false matrix: - python: [ 3.7, 3.9, "3.10" ] + python: [ 3.7, 3.9, "3.11" ] os: [ macos-latest, ubuntu-latest, windows-latest ] defaults: run: @@ -143,6 +143,11 @@ jobs: conda activate anaconda-client-env python setup.py build_ext --inplace + - name: Remove py311 incompatible tests (lack of numba support for 3.11, needed for lshmm) + if: matrix.python == '3.11' + run: | + rm python/tests/test_genotype_matching_* + - name: Run tests working-directory: python run: | @@ -152,7 +157,6 @@ jobs: python -m pytest -x --cov=tskit --cov-report=xml --cov-branch -n2 tests - name: Upload coverage to Codecov - if: matrix.os == 'ubuntu-latest' uses: codecov/codecov-action@v2 with: working-directory: python diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index bdd48f89cd..81ab28f01e 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -15,12 +15,12 @@ jobs: runs-on: macos-latest strategy: matrix: - python: [3.7, 3.8, 3.9, "3.10"] + python: [3.7, 3.8, 3.9, "3.10", 3.11] steps: - name: Checkout uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python }} - name: Install deps @@ -55,7 +55,7 @@ jobs: runs-on: windows-latest strategy: matrix: - python: [3.7, 3.8, 3.9, "3.10"] + python: [3.7, 3.8, 3.9, "3.10", 3.11] wordsize: [64] steps: - name: Checkout @@ -108,7 +108,7 @@ jobs: uses: actions/checkout@v2 - name: Set up Python 3.8 - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: 3.8 @@ -140,14 +140,14 @@ jobs: runs-on: macos-latest strategy: 
matrix: - python: [3.7, 3.8, 3.9, "3.10"] + python: [3.7, 3.8, 3.9, "3.10", 3.11] steps: - name: Download wheels uses: actions/download-artifact@v2 with: name: osx-wheel-${{ matrix.python }} - name: Set up Python ${{ matrix.python }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python }} - name: Install wheel and test @@ -162,7 +162,7 @@ jobs: runs-on: windows-latest strategy: matrix: - python: [3.7, 3.8, 3.9, "3.10"] + python: [3.7, 3.8, 3.9, "3.10", 3.11] wordsize: [64] steps: - name: Download wheels @@ -170,7 +170,7 @@ jobs: with: name: win-wheel-${{ matrix.python }}-${{ matrix.wordsize }} - name: Set up Python ${{ matrix.python }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python }} - name: Install wheel and test @@ -186,7 +186,7 @@ jobs: needs: ['manylinux'] strategy: matrix: - python: [3.7, 3.8, 3.9, "3.10"] + python: [3.7, 3.8, 3.9, "3.10", 3.11] include: - python: 3.7 wheel: cp37 @@ -196,13 +196,15 @@ jobs: wheel: cp39 - python: "3.10" wheel: cp310 + - python: 3.11 + wheel: cp311 steps: - name: Download wheels uses: actions/download-artifact@v2 with: name: linux-wheels - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python }} - name: Install wheel and test diff --git a/.mergify.yml b/.mergify.yml index 8a2b976d8b..2ab3025a55 100644 --- a/.mergify.yml +++ b/.mergify.yml @@ -7,10 +7,13 @@ queue_rules: - status-success=Lint - status-success=Python (3.7, macos-latest) - status-success=Python (3.9, macos-latest) + - status-success=Python (3.11, macos-latest) - status-success=Python (3.7, ubuntu-latest) - status-success=Python (3.9, ubuntu-latest) + - status-success=Python (3.11, ubuntu-latest) - status-success=Python (3.7, windows-latest) - status-success=Python (3.9, windows-latest) + - status-success=Python (3.11, windows-latest) - "status-success=ci/circleci: build" pull_request_rules: 
- name: Automatic rebase, CI and merge @@ -24,10 +27,13 @@ pull_request_rules: - status-success=Lint - status-success=Python (3.7, macos-latest) - status-success=Python (3.9, macos-latest) + - status-success=Python (3.11, macos-latest) - status-success=Python (3.7, ubuntu-latest) - status-success=Python (3.9, ubuntu-latest) + - status-success=Python (3.11, ubuntu-latest) - status-success=Python (3.7, windows-latest) - status-success=Python (3.9, windows-latest) + - status-success=Python (3.11, windows-latest) - "status-success=ci/circleci: build" #- status-success=codecov/patch #- status-success=codecov/project/c-tests @@ -59,10 +65,13 @@ pull_request_rules: - status-success=Lint - status-success=Python (3.7, macos-latest) - status-success=Python (3.9, macos-latest) + - status-success=Python (3.11, macos-latest) - status-success=Python (3.7, ubuntu-latest) - status-success=Python (3.9, ubuntu-latest) + - status-success=Python (3.11, ubuntu-latest) - status-success=Python (3.7, windows-latest) - status-success=Python (3.9, windows-latest) + - status-success=Python (3.11, windows-latest) - "status-success=ci/circleci: build" - "status-success=ci/circleci: build-32" - status-success=codecov/patch diff --git a/python/requirements/CI-docs/requirements.txt b/python/requirements/CI-docs/requirements.txt index d4a1f0c9fc..6278cfce2d 100644 --- a/python/requirements/CI-docs/requirements.txt +++ b/python/requirements/CI-docs/requirements.txt @@ -1,7 +1,7 @@ breathe==4.34.0 jupyter-book==0.13.1 h5py==3.7.0 -jsonschema==3.2.0 #jupyter-book 0.13.1 depends on jsonschema<4 +jsonschema[format-nongpl]==4.17.3 msprime==1.2.0 numpy==1.21.6 # Held at 1.21.6 for Python 3.7 compatibility PyGithub==1.55 diff --git a/python/requirements/CI-tests-conda/requirements.txt b/python/requirements/CI-tests-conda/requirements.txt index 5a632c2821..453d829799 100644 --- a/python/requirements/CI-tests-conda/requirements.txt +++ b/python/requirements/CI-tests-conda/requirements.txt @@ -1,3 +1,4 @@ 
msprime==1.2.0 kastore==0.3.2 jsonschema==4.16.0 +h5py==3.7.0 diff --git a/python/requirements/CI-tests-pip/requirements.txt b/python/requirements/CI-tests-pip/requirements.txt index f168af3e01..f83b753a5d 100644 --- a/python/requirements/CI-tests-pip/requirements.txt +++ b/python/requirements/CI-tests-pip/requirements.txt @@ -1,9 +1,9 @@ -lshmm==0.0.4 -numpy==1.21.6 # Held at 1.21.6 for Python 3.7 compatibility +lshmm==0.0.4; python_version < '3.11' +numpy==1.21.6; python_version < '3.11' # Held at 1.21.6 for Python 3.7 compatibility +numpy==1.24.1; python_version > '3.10' pytest==7.1.3 pytest-cov==4.0.0 pytest-xdist==2.5.0 -h5py==3.7.0 svgwrite==1.4.3 portion==2.3.0 xmlunittest==0.5.0 diff --git a/python/setup.cfg b/python/setup.cfg index aca17749de..13a1506f25 100644 --- a/python/setup.cfg +++ b/python/setup.cfg @@ -20,6 +20,7 @@ classifiers = Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.9 Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 Programming Language :: Python :: 3 :: Only Development Status :: 5 - Production/Stable Environment :: Other Environment From 476bd63a9716f1f83d70cfad199a700b4e0edfe0 Mon Sep 17 00:00:00 2001 From: Shing Zhan Date: Sun, 30 Oct 2022 19:13:16 +0000 Subject: [PATCH 12/84] Add Tree.siblings() --- .github/workflows/docs.yml | 2 +- .github/workflows/tests.yml | 2 +- python/CHANGELOG.rst | 6 ++++ python/tests/test_highlevel.py | 63 ++++++++++++++++++++++++++++++++++ python/tskit/trees.py | 27 +++++++++++++++ 5 files changed, 98 insertions(+), 2 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 201ecd3a4e..83c36e4fd3 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -36,7 +36,7 @@ jobs: id: cache with: path: venv - key: docs-venv-v4-${{ hashFiles(env.REQUIREMENTS) }} + key: docs-venv-v5-${{ hashFiles(env.REQUIREMENTS) }} - name: Build virtualenv if: steps.cache.outputs.cache-hit != 'true' diff --git 
a/.github/workflows/tests.yml b/.github/workflows/tests.yml index bec16ccd64..4846fcc3f3 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -157,7 +157,7 @@ jobs: python -m pytest -x --cov=tskit --cov-report=xml --cov-branch -n2 tests - name: Upload coverage to Codecov - uses: codecov/codecov-action@v2 + uses: codecov/codecov-action@v3 with: working-directory: python fail_ci_if_error: true diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index c3dc69c1e5..1f33ba6d09 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -19,6 +19,12 @@ extra room on the canvas e.g. for long labels or repositioned graphical elements (:user:`hyanwong`, :pr:`2646`, :issue:`2645`) +- The ``Tree`` object now has the method ``siblings`` to get + the siblings of a node. It returns an empty tuple if the node + has no siblings, is not a node in the tree, is the virtual root, + or is an isolated non-sample node. + (:user:`szhan`, :pr:`2618`, :issue:`2616`) + **Breaking Changes** - the ``filter_populations``, ``filter_individuals``, and ``filter_sites`` diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index 3c62fa476a..3e3a0de8cf 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -4317,6 +4317,69 @@ def test_node_edges(self): assert edge == tskit.NULL +class TestSiblings: + def test_balanced_binary_tree(self): + t = tskit.Tree.generate_balanced(num_leaves=3) + assert t.has_single_root + # Nodes 0 to 2 are leaves + for u in range(2): + assert t.is_leaf(u) + assert t.siblings(0) == (3,) + assert t.siblings(1) == (2,) + assert t.siblings(2) == (1,) + # Node 3 is the internal node + assert t.is_internal(3) + assert t.siblings(3) == (0,) + # Node 4 is the root + assert 4 == t.root + assert t.siblings(4) == tuple() + # Node 5 is the virtual root + assert 5 == t.virtual_root + assert t.siblings(5) == tuple() + + def test_star(self): + t = tskit.Tree.generate_star(num_leaves=3) + assert 
t.has_single_root + # Nodes 0 to 2 are leaves + for u in range(2): + assert t.is_leaf(u) + assert t.siblings(0) == (1, 2) + assert t.siblings(1) == (0, 2) + assert t.siblings(2) == (0, 1) + # Node 3 is the root + assert 3 == t.root + assert t.siblings(3) == tuple() + # Node 4 is the virtual root + assert 4 == t.virtual_root + assert t.siblings(4) == tuple() + + def test_multiroot_tree(self): + ts = tskit.Tree.generate_balanced(4, arity=2).tree_sequence + t = ts.decapitate(ts.node(5).time).first() + assert t.has_multiple_roots + # Nodes 0 to 3 are leaves + assert t.siblings(0) == (1,) + assert t.siblings(1) == (0,) + assert t.siblings(2) == (3,) + assert t.siblings(3) == (2,) + # Nodes 4 and 5 are both roots + assert 4 in t.roots + assert t.siblings(4) == (5,) + assert 5 in t.roots + assert t.siblings(5) == (4,) + # Node 7 is the virtual root + assert 7 == t.virtual_root + assert t.siblings(7) == tuple() + + @pytest.mark.parametrize("flag,expected", [(0, ()), (1, (2,))]) + def test_isolated_node(self, flag, expected): + tables = tskit.Tree.generate_balanced(2, arity=2).tree_sequence.dump_tables() + tables.nodes.add_row(flags=flag) # Add node 3 + t = tables.tree_sequence().first() + assert t.is_isolated(3) + assert t.siblings(3) == expected + + class TestNodeOrdering(HighLevelTestCase): """ Verify that we can use any node ordering for internal nodes diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 60fe9699e5..d560e29eec 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -1208,6 +1208,33 @@ def right_sib_array(self): """ return self._right_sib_array + def siblings(self, u): + """ + Returns the sibling(s) of the specified node ``u`` as a tuple of integer + node IDs. If ``u`` has no siblings or is not a node in the current tree, + returns an empty tuple. 
If ``u`` is the root of a single-root tree, + returns an empty tuple; if ``u`` is the root of a multi-root tree, + returns the other roots (note all the roots are related by the virtual root). + If ``u`` is the virtual root (which has no siblings), returns an empty tuple. + If ``u`` is an isolated node, whether it has siblings or not depends on + whether it is a sample or non-sample node; if it is a sample node, + returns the root(s) of the tree, otherwise, returns an empty tuple. + The ordering of siblings is arbitrary and should not be depended on; + see the :ref:`data model ` section for details. + + :param int u: The node of interest. + :return: The siblings of ``u``. + :rtype: tuple(int) + """ + if u == self.virtual_root: + return tuple() + parent = self.parent(u) + if self.is_root(u): + parent = self.virtual_root + if parent != tskit.NULL: + return tuple(v for v in self.children(parent) if u != v) + return tuple() + @property def num_children_array(self): """ From 204cb25ed2cfcd250ad6a8835f748a2399ca97e3 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Fri, 13 Jan 2023 00:02:11 +0000 Subject: [PATCH 13/84] Clear docs CI cache --- .github/workflows/docs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 83c36e4fd3..c2fcc823e5 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -36,7 +36,7 @@ jobs: id: cache with: path: venv - key: docs-venv-v5-${{ hashFiles(env.REQUIREMENTS) }} + key: docs-venv-v6-${{ hashFiles(env.REQUIREMENTS) }} - name: Build virtualenv if: steps.cache.outputs.cache-hit != 'true' From 1ffcd0686257c32a7f15de2828668dab86a1c9b0 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Thu, 12 Jan 2023 23:48:56 +0000 Subject: [PATCH 14/84] Add codecov token --- .github/workflows/tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4846fcc3f3..abbbeae854 100644 --- 
a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -159,6 +159,7 @@ jobs: - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 with: + token: ${{ secrets.CODECOV_TOKEN }} working-directory: python fail_ci_if_error: true flags: python-tests From 0b92eed1de69fc0ace595aee45502d9acb57e4fc Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Tue, 10 Jan 2023 14:19:19 +0000 Subject: [PATCH 15/84] CODE IMPORT: copy intervals code from msprime The authors of this module and associated tests have agreed to release the code originally developed for the msprime package under the MIT licence (see tskit-dev#2636). Specifically: @jeromekelleher: "I approve the relicensing of this code originally developed in msprime from GPLv3 to MIT." @hyanwong: "I am happy to relicense all of my contributions to this code (originally developed in msprime) from GPLv3 to MIT." @awohns: "I also approve the relicensing of this code originally developed in msprime from GPLv3 to MIT." @grahamgower: "I approve the relicensing of this code originally developed in msprime from GPLv3 to MIT." @petrelharp: "I approve the license change!" --- python/tests/test_intervals.py | 845 +++++++++++++++++++++++++++++++++ python/tskit/intervals.py | 759 +++++++++++++++++++++++++++++ 2 files changed, 1604 insertions(+) create mode 100644 python/tests/test_intervals.py create mode 100644 python/tskit/intervals.py diff --git a/python/tests/test_intervals.py b/python/tests/test_intervals.py new file mode 100644 index 0000000000..8d39d1838a --- /dev/null +++ b/python/tests/test_intervals.py @@ -0,0 +1,845 @@ +# +# +# This file is part of msprime. +# +# msprime is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. 
+# +# msprime is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with msprime. If not, see . +# +""" +Test cases for the intervals module. +""" +import decimal +import fractions +import gzip +import io +import os +import pickle +import textwrap +import xml + +import msprime +import numpy as np +import pytest +from numpy.testing import assert_array_equal + + +class TestRateMapErrors: + @pytest.mark.parametrize( + ("position", "rate"), + [ + ([], []), + ([0], []), + ([0], [0]), + ([1, 2], [0]), + ([0, -1], [0]), + ([0, 1], [-1]), + ], + ) + def test_bad_input(self, position, rate): + with pytest.raises(ValueError): + msprime.RateMap(position=position, rate=rate) + + def test_zero_length_interval(self): + with pytest.raises(ValueError, match=r"at indexes \[2 4\]"): + msprime.RateMap(position=[0, 1, 1, 2, 2, 3], rate=[0, 0, 0, 0, 0]) + + def test_bad_length(self): + positions = np.array([0, 1, 2]) + rates = np.array([0, 1, 2]) + with pytest.raises(ValueError, match="one less entry"): + msprime.RateMap(position=positions, rate=rates) + + def test_bad_first_pos(self): + positions = np.array([1, 2, 3]) + rates = np.array([1, 1]) + with pytest.raises(ValueError, match="First position"): + msprime.RateMap(position=positions, rate=rates) + + def test_bad_rate(self): + positions = np.array([0, 1, 2]) + rates = np.array([1, -1]) + with pytest.raises(ValueError, match="negative.*1"): + msprime.RateMap(position=positions, rate=rates) + + def test_bad_rate_with_missing(self): + positions = np.array([0, 1, 2]) + rates = np.array([np.nan, -1]) + with pytest.raises(ValueError, match="negative.*1"): + msprime.RateMap(position=positions, rate=rates) + + def test_read_only(self): + positions = np.array([0, 0.25, 0.5, 0.75, 1]) + 
rates = np.array([0.125, 0.25, 0.5, 0.75]) # 1 shorter than positions + rate_map = msprime.RateMap(position=positions, rate=rates) + assert np.all(rates == rate_map.rate) + assert np.all(positions == rate_map.position) + with pytest.raises(AttributeError): + rate_map.rate = 2 * rate_map.rate + with pytest.raises(AttributeError): + rate_map.position = 2 * rate_map.position + with pytest.raises(AttributeError): + rate_map.left = 1234 + with pytest.raises(AttributeError): + rate_map.right = 1234 + with pytest.raises(AttributeError): + rate_map.mid = 1234 + with pytest.raises(ValueError): + rate_map.rate[0] = 1 + with pytest.raises(ValueError): + rate_map.position[0] = 1 + with pytest.raises(ValueError): + rate_map.left[0] = 1 + with pytest.raises(ValueError): + rate_map.mid[0] = 1 + with pytest.raises(ValueError): + rate_map.right[0] = 1 + + +class TestGetRateAllKnown: + examples = [ + msprime.RateMap(position=[0, 1], rate=[0]), + msprime.RateMap(position=[0, 1], rate=[0.1]), + msprime.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]), + msprime.RateMap(position=[0, 1, 2], rate=[0, 0.2]), + msprime.RateMap(position=[0, 1, 2], rate=[0.1, 1e-6]), + msprime.RateMap(position=range(100), rate=range(99)), + ] + + @pytest.mark.parametrize("rate_map", examples) + def test_get_rate_mid(self, rate_map): + rate = rate_map.get_rate(rate_map.mid) + assert len(rate) == len(rate_map) + for j in range(len(rate_map)): + assert rate[j] == rate_map[rate_map.mid[j]] + + @pytest.mark.parametrize("rate_map", examples) + def test_get_rate_left(self, rate_map): + rate = rate_map.get_rate(rate_map.left) + assert len(rate) == len(rate_map) + for j in range(len(rate_map)): + assert rate[j] == rate_map[rate_map.left[j]] + + @pytest.mark.parametrize("rate_map", examples) + def test_get_rate_right(self, rate_map): + rate = rate_map.get_rate(rate_map.right[:-1]) + assert len(rate) == len(rate_map) - 1 + for j in range(len(rate_map) - 1): + assert rate[j] == rate_map[rate_map.right[j]] + + +class 
TestOperations: + examples = [ + msprime.RateMap(position=[0, 1], rate=[0]), + msprime.RateMap(position=[0, 1], rate=[0.1]), + msprime.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]), + msprime.RateMap(position=[0, 1, 2], rate=[0, 0.2]), + msprime.RateMap(position=[0, 1, 2], rate=[0.1, 1e-6]), + msprime.RateMap(position=range(100), rate=range(99)), + # Missing data + msprime.RateMap(position=[0, 1, 2], rate=[np.nan, 0]), + msprime.RateMap(position=[0, 1, 2], rate=[0, np.nan]), + msprime.RateMap(position=[0, 1, 2, 3], rate=[0, np.nan, 1]), + ] + + @pytest.mark.parametrize("rate_map", examples) + def test_num_intervals(self, rate_map): + assert rate_map.num_intervals == len(rate_map.rate) + assert rate_map.num_missing_intervals == np.sum(np.isnan(rate_map.rate)) + assert rate_map.num_non_missing_intervals == np.sum(~np.isnan(rate_map.rate)) + + @pytest.mark.parametrize("rate_map", examples) + def test_mask_arrays(self, rate_map): + assert_array_equal(rate_map.missing, np.isnan(rate_map.rate)) + assert_array_equal(rate_map.non_missing, ~np.isnan(rate_map.rate)) + + @pytest.mark.parametrize("rate_map", examples) + def test_missing_intervals(self, rate_map): + missing = [] + for left, right, rate in zip(rate_map.left, rate_map.right, rate_map.rate): + if np.isnan(rate): + missing.append([left, right]) + if len(missing) == 0: + assert len(rate_map.missing_intervals()) == 0 + else: + assert_array_equal(missing, rate_map.missing_intervals()) + + @pytest.mark.parametrize("rate_map", examples) + def test_mean_rate(self, rate_map): + total_span = 0 + total_mass = 0 + for span, mass in zip(rate_map.span, rate_map.mass): + if not np.isnan(mass): + total_span += span + total_mass += mass + assert total_mass / total_span == rate_map.mean_rate + + @pytest.mark.parametrize("rate_map", examples) + def test_total_mass(self, rate_map): + assert rate_map.total_mass == np.nansum(rate_map.mass) + + @pytest.mark.parametrize("rate_map", examples) + def test_get_cumulative_mass(self, 
rate_map): + assert list(rate_map.get_cumulative_mass([0])) == [0] + assert list(rate_map.get_cumulative_mass([rate_map.sequence_length])) == [ + rate_map.total_mass + ] + assert_array_equal( + rate_map.get_cumulative_mass(rate_map.right), np.nancumsum(rate_map.mass) + ) + + @pytest.mark.parametrize("rate_map", examples) + def test_get_rate(self, rate_map): + assert_array_equal(rate_map.get_rate([0]), rate_map.rate[0]) + assert_array_equal( + rate_map.get_rate([rate_map.sequence_length - 1e-9]), rate_map.rate[-1] + ) + assert_array_equal(rate_map.get_rate(rate_map.left), rate_map.rate) + + @pytest.mark.parametrize("rate_map", examples) + def test_map_semantics(self, rate_map): + assert len(rate_map) == rate_map.num_non_missing_intervals + assert_array_equal(list(rate_map.keys()), rate_map.mid[rate_map.non_missing]) + for x in rate_map.left[rate_map.missing]: + assert x not in rate_map + for x in rate_map.mid[rate_map.missing]: + assert x not in rate_map + + +class TestFindIndex: + def test_one_interval(self): + rate_map = msprime.RateMap(position=[0, 10], rate=[0.1]) + for j in range(10): + assert rate_map.find_index(j) == 0 + assert rate_map.find_index(0.0001) == 0 + assert rate_map.find_index(9.999) == 0 + + def test_two_intervals(self): + rate_map = msprime.RateMap(position=[0, 5, 10], rate=[0.1, 0.1]) + assert rate_map.find_index(0) == 0 + assert rate_map.find_index(0.0001) == 0 + assert rate_map.find_index(4.9999) == 0 + assert rate_map.find_index(5) == 1 + assert rate_map.find_index(5.1) == 1 + assert rate_map.find_index(7) == 1 + assert rate_map.find_index(9.999) == 1 + + def test_three_intervals(self): + rate_map = msprime.RateMap(position=[0, 5, 10, 15], rate=[0.1, 0.1, 0.1]) + assert rate_map.find_index(0) == 0 + assert rate_map.find_index(0.0001) == 0 + assert rate_map.find_index(4.9999) == 0 + assert rate_map.find_index(5) == 1 + assert rate_map.find_index(5.1) == 1 + assert rate_map.find_index(7) == 1 + assert rate_map.find_index(9.999) == 1 + assert 
rate_map.find_index(10) == 2 + assert rate_map.find_index(10.1) == 2 + assert rate_map.find_index(12) == 2 + assert rate_map.find_index(14.9999) == 2 + + def test_out_of_bounds(self): + rate_map = msprime.RateMap(position=[0, 10], rate=[0.1]) + for bad_value in [-1, -0.0001, 10, 10.0001, 1e9]: + with pytest.raises(KeyError, match="out of bounds"): + rate_map.find_index(bad_value) + + def test_input_types(self): + rate_map = msprime.RateMap(position=[0, 10], rate=[0.1]) + assert rate_map.find_index(0) == 0 + assert rate_map.find_index(0.0) == 0 + assert rate_map.find_index(np.zeros(1)[0]) == 0 + + +class TestSimpleExamples: + def test_all_missing_one_interval(self): + with pytest.raises(ValueError, match="missing data"): + msprime.RateMap(position=[0, 10], rate=[np.nan]) + + def test_all_missing_two_intervals(self): + with pytest.raises(ValueError, match="missing data"): + msprime.RateMap(position=[0, 5, 10], rate=[np.nan, np.nan]) + + def test_count(self): + rate_map = msprime.RateMap(position=[0, 5, 10], rate=[np.nan, 1]) + assert rate_map.num_intervals == 2 + assert rate_map.num_missing_intervals == 1 + assert rate_map.num_non_missing_intervals == 1 + + def test_missing_arrays(self): + rate_map = msprime.RateMap(position=[0, 5, 10], rate=[np.nan, 1]) + assert list(rate_map.missing) == [True, False] + assert list(rate_map.non_missing) == [False, True] + + def test_missing_at_start_mean_rate(self): + positions = np.array([0, 0.5, 1, 2]) + rates = np.array([np.nan, 0, 1]) + rate_map = msprime.RateMap(position=positions, rate=rates) + assert np.isclose(rate_map.mean_rate, 1 / (1 + 0.5)) + + def test_missing_at_end_mean_rate(self): + positions = np.array([0, 1, 1.5, 2]) + rates = np.array([1, 0, np.nan]) + rate_map = msprime.RateMap(position=positions, rate=rates) + assert np.isclose(rate_map.mean_rate, 1 / (1 + 0.5)) + + def test_interval_properties_all_known(self): + rate_map = msprime.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + assert list(rate_map.left) 
== [0, 1, 2] + assert list(rate_map.right) == [1, 2, 3] + assert list(rate_map.mid) == [0.5, 1.5, 2.5] + assert list(rate_map.span) == [1, 1, 1] + assert list(rate_map.mass) == [0.1, 0.2, 0.3] + + def test_pickle_non_missing(self): + r1 = msprime.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + r2 = pickle.loads(pickle.dumps(r1)) + assert r1 == r2 + + def test_pickle_missing(self): + r1 = msprime.RateMap(position=[0, 1, 2, 3], rate=[0.1, np.nan, 0.3]) + r2 = pickle.loads(pickle.dumps(r1)) + assert r1 == r2 + + def test_get_cumulative_mass_all_known(self): + rate_map = msprime.RateMap(position=[0, 10, 20, 30], rate=[0.1, 0.2, 0.3]) + assert list(rate_map.mass) == [1, 2, 3] + assert list(rate_map.get_cumulative_mass([10, 20, 30])) == [1, 3, 6] + + def test_cumulative_mass_missing(self): + rate_map = msprime.RateMap(position=[0, 10, 20, 30], rate=[0.1, np.nan, 0.3]) + assert list(rate_map.get_cumulative_mass([10, 20, 30])) == [1, 1, 4] + + +class TestDisplay: + def test_str(self): + rate_map = msprime.RateMap(position=[0, 10], rate=[0.1]) + s = """ + ┌──────────────────────────────────┐ + │left │right │ mid│ span│ rate│ + ├──────────────────────────────────┤ + │0 │10 │ 5│ 10│ 0.1│ + └──────────────────────────────────┘ + """ + assert textwrap.dedent(s) == str(rate_map) + + def test_str_scinot(self): + rate_map = msprime.RateMap(position=[0, 10], rate=[0.000001]) + s = """ + ┌───────────────────────────────────┐ + │left │right │ mid│ span│ rate│ + ├───────────────────────────────────┤ + │0 │10 │ 5│ 10│ 1e-06│ + └───────────────────────────────────┘ + """ + assert textwrap.dedent(s) == str(rate_map) + + def test_repr(self): + rate_map = msprime.RateMap(position=[0, 10], rate=[0.1]) + s = "RateMap(position=array([ 0., 10.]), rate=array([0.1]))" + assert repr(rate_map) == s + + def test_repr_html(self): + rate_map = msprime.RateMap(position=[0, 10], rate=[0.1]) + html = rate_map._repr_html_() + root = xml.etree.ElementTree.fromstring(html) + assert root.tag == "div" 
+ table = root.find("table") + rows = list(table.find("tbody")) + assert len(rows) == 1 + + def test_long_table(self): + n = 100 + rate_map = msprime.RateMap(position=range(n + 1), rate=[0.1] * n) + headers, data = rate_map._display_table() + assert len(headers) == 5 + assert len(data) == 21 + # check some left values + assert int(data[0][0]) == 0 + assert int(data[-1][0]) == n - 1 + + def test_short_table(self): + n = 10 + rate_map = msprime.RateMap(position=range(n + 1), rate=[0.1] * n) + headers, data = rate_map._display_table() + assert len(headers) == 5 + assert len(data) == n + # check some left values. + assert int(data[0][0]) == 0 + assert int(data[-1][0]) == n - 1 + + +class TestRateMapIsMapping: + def test_items(self): + rate_map = msprime.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + items = list(rate_map.items()) + assert items[0] == (0.5, 0.1) + assert items[1] == (1.5, 0.2) + assert items[2] == (2.5, 0.3) + + def test_keys(self): + rate_map = msprime.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + assert list(rate_map.keys()) == [0.5, 1.5, 2.5] + + def test_values(self): + rate_map = msprime.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + assert list(rate_map.values()) == [0.1, 0.2, 0.3] + + def test_in_points(self): + rate_map = msprime.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + # Any point within the map are True + for x in [0, 0.5, 1, 2.9999]: + assert x in rate_map + # Points outside the map are False + for x in [-1, -0.0001, 3, 3.1]: + assert x not in rate_map + + def test_in_slices(self): + rate_map = msprime.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + # slices that are within the map are "in" + for x in [slice(0, 0.5), slice(0, 1), slice(0, 2), slice(2, 3), slice(0, 3)]: + assert x in rate_map + # Any slice that doesn't fully intersect with the map "not in" + assert slice(-0.001, 1) not in rate_map + assert slice(0, 3.0001) not in rate_map + assert slice(2.9999, 3.0001) not in rate_map + assert 
slice(3, 4) not in rate_map + assert slice(-2, -1) not in rate_map + + def test_other_types_not_in(self): + rate_map = msprime.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + for other_type in [None, "sdf", "123", {}, [], Exception]: + assert other_type not in rate_map + + def test_len(self): + rate_map = msprime.RateMap(position=[0, 1], rate=[0.1]) + assert len(rate_map) == 1 + rate_map = msprime.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]) + assert len(rate_map) == 2 + rate_map = msprime.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + assert len(rate_map) == 3 + + def test_immutable(self): + rate_map = msprime.RateMap(position=[0, 1], rate=[0.1]) + with pytest.raises(TypeError, match="item assignment"): + rate_map[0] = 1 + with pytest.raises(TypeError, match="item deletion"): + del rate_map[0] + + def test_eq(self): + r1 = msprime.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]) + r2 = msprime.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]) + assert r1 == r1 + assert r1 == r2 + r2 = msprime.RateMap(position=[0, 1, 3], rate=[0.1, 0.2]) + assert r1 != r2 + assert msprime.RateMap(position=[0, 1], rate=[0.1]) != msprime.RateMap( + position=[0, 1], rate=[0.2] + ) + assert msprime.RateMap(position=[0, 1], rate=[0.1]) != msprime.RateMap( + position=[0, 10], rate=[0.1] + ) + + def test_getitem_value(self): + rate_map = msprime.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]) + assert rate_map[0] == 0.1 + assert rate_map[0.5] == 0.1 + assert rate_map[1] == 0.2 + assert rate_map[1.5] == 0.2 + assert rate_map[1.999] == 0.2 + # Try other types + assert rate_map[np.array([1], dtype=np.float32)[0]] == 0.2 + assert rate_map[np.array([1], dtype=np.int32)[0]] == 0.2 + assert rate_map[np.array([1], dtype=np.float64)[0]] == 0.2 + assert rate_map[1 / 2] == 0.1 + assert rate_map[fractions.Fraction(1, 3)] == 0.1 + assert rate_map[decimal.Decimal(1)] == 0.2 + + def test_getitem_slice(self): + r1 = msprime.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]) + # The semantics of the slice() 
function are tested elsewhere. + assert r1[:] == r1.copy() + assert r1[:] is not r1 + assert r1[1:] == r1.slice(left=1) + assert r1[:1.5] == r1.slice(right=1.5) + assert r1[0.5:1.5] == r1.slice(left=0.5, right=1.5) + + def test_getitem_slice_step(self): + r1 = msprime.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]) + # Trying to set a "step" is a error + with pytest.raises(TypeError, match="interval slicing"): + r1[0:3:1] + + +class TestMappingMissingData: + def test_get_missing(self): + rate_map = msprime.RateMap(position=[0, 1, 2], rate=[np.nan, 0.2]) + with pytest.raises(KeyError, match="within a missing interval"): + rate_map[0] + with pytest.raises(KeyError, match="within a missing interval"): + rate_map[0.999] + + def test_in_missing(self): + rate_map = msprime.RateMap(position=[0, 1, 2], rate=[np.nan, 0.2]) + assert 0 not in rate_map + assert 0.999 not in rate_map + assert 1 in rate_map + + def test_keys_missing(self): + rate_map = msprime.RateMap(position=[0, 1, 2], rate=[np.nan, 0.2]) + assert list(rate_map.keys()) == [1.5] + + +class TestGetIntermediates: + def test_get_rate(self): + positions = np.array([0, 1, 2]) + rates = np.array([1, 4]) + rate_map = msprime.RateMap(position=positions, rate=rates) + assert np.all(rate_map.get_rate([0.5, 1.5]) == rates) + + def test_get_rate_out_of_bounds(self): + positions = np.array([0, 1, 2]) + rates = np.array([1, 4]) + rate_map = msprime.RateMap(position=positions, rate=rates) + with pytest.raises(ValueError, match="out of bounds"): + rate_map.get_rate([1, -0.1]) + with pytest.raises(ValueError, match="out of bounds"): + rate_map.get_rate([2]) + + def test_get_cumulative_mass(self): + positions = np.array([0, 1, 2]) + rates = np.array([1, 4]) + rate_map = msprime.RateMap(position=positions, rate=rates) + assert np.allclose(rate_map.get_cumulative_mass([0.5, 1.5]), np.array([0.5, 3])) + assert rate_map.get_cumulative_mass([2]) == rate_map.total_mass + + def test_get_bad_cumulative_mass(self): + positions = 
np.array([0, 1, 2]) + rates = np.array([1, 4]) + rate_map = msprime.RateMap(position=positions, rate=rates) + with pytest.raises(ValueError, match="positions"): + rate_map.get_cumulative_mass([1, -0.1]) + with pytest.raises(ValueError, match="positions"): + rate_map.get_cumulative_mass([1, 2.1]) + + +class TestSlice: + def test_slice_no_params(self): + # test RateMap.slice(..., trim=False) + a = msprime.RateMap(position=[0, 100, 200, 300, 400], rate=[0, 1, 2, 3]) + b = a.slice() + assert a.sequence_length == b.sequence_length + assert_array_equal(a.position, b.position) + assert_array_equal(a.rate, b.rate) + assert a == b + + def test_slice_left_examples(self): + a = msprime.RateMap(position=[0, 100, 200, 300, 400], rate=[0, 1, 2, 3]) + b = a.slice(left=50) + assert a.sequence_length == b.sequence_length + assert_array_equal([0, 50, 100, 200, 300, 400], b.position) + assert_array_equal([np.nan, 0, 1, 2, 3], b.rate) + + b = a.slice(left=100) + assert a.sequence_length == b.sequence_length + assert_array_equal([0, 100, 200, 300, 400], b.position) + assert_array_equal([np.nan, 1, 2, 3], b.rate) + + b = a.slice(left=150) + assert a.sequence_length == b.sequence_length + assert_array_equal([0, 150, 200, 300, 400], b.position) + assert_array_equal([np.nan, 1, 2, 3], b.rate) + + def test_slice_right_examples(self): + a = msprime.RateMap(position=[0, 100, 200, 300, 400], rate=[0, 1, 2, 3]) + b = a.slice(right=300) + assert a.sequence_length == b.sequence_length + assert_array_equal([0, 100, 200, 300, 400], b.position) + assert_array_equal([0, 1, 2, np.nan], b.rate) + + b = a.slice(right=250) + assert a.sequence_length == b.sequence_length + assert_array_equal([0, 100, 200, 250, 400], b.position) + assert_array_equal([0, 1, 2, np.nan], b.rate) + + def test_slice_left_right_examples(self): + a = msprime.RateMap(position=[0, 100, 200, 300, 400], rate=[0, 1, 2, 3]) + b = a.slice(left=50, right=300) + assert a.sequence_length == b.sequence_length + assert_array_equal([0, 50, 
100, 200, 300, 400], b.position) + assert_array_equal([np.nan, 0, 1, 2, np.nan], b.rate) + + b = a.slice(left=150, right=250) + assert a.sequence_length == b.sequence_length + assert_array_equal([0, 150, 200, 250, 400], b.position) + assert_array_equal([np.nan, 1, 2, np.nan], b.rate) + + b = a.slice(left=150, right=300) + assert a.sequence_length == b.sequence_length + assert_array_equal([0, 150, 200, 300, 400], b.position) + assert_array_equal([np.nan, 1, 2, np.nan], b.rate) + + b = a.slice(left=150, right=160) + assert a.sequence_length == b.sequence_length + assert_array_equal([0, 150, 160, 400], b.position) + assert_array_equal([np.nan, 1, np.nan], b.rate) + + def test_slice_right_missing(self): + # If we take a right-slice into a trailing missing region, + # we should recover the same map. + a = msprime.RateMap(position=[0, 100, 200, 300, 400], rate=[0, 1, 2, np.nan]) + b = a.slice(right=350) + assert a.sequence_length == b.sequence_length + assert_array_equal(a.position, b.position) + assert_array_equal(a.rate, b.rate) + + b = a.slice(right=300) + assert a.sequence_length == b.sequence_length + assert_array_equal(a.position, b.position) + assert_array_equal(a.rate, b.rate) + + def test_slice_left_missing(self): + a = msprime.RateMap(position=[0, 100, 200, 300, 400], rate=[np.nan, 1, 2, 3]) + b = a.slice(left=50) + assert a.sequence_length == b.sequence_length + assert_array_equal(a.position, b.position) + assert_array_equal(a.rate, b.rate) + + b = a.slice(left=100) + assert a.sequence_length == b.sequence_length + assert_array_equal(a.position, b.position) + assert_array_equal(a.rate, b.rate) + + def test_slice_with_floats(self): + # test RateMap.slice(..., trim=False) with floats + a = msprime.RateMap( + position=[np.pi * x for x in [0, 100, 200, 300, 400]], rate=[0, 1, 2, 3] + ) + b = a.slice(left=50 * np.pi) + assert a.sequence_length == b.sequence_length + assert_array_equal([0, 50 * np.pi] + list(a.position[1:]), b.position) + assert_array_equal([np.nan] 
+ list(a.rate), b.rate) + + def test_slice_trim_left(self): + a = msprime.RateMap(position=[0, 100, 200, 300, 400], rate=[1, 2, 3, 4]) + b = a.slice(left=100, trim=True) + assert b == msprime.RateMap(position=[0, 100, 200, 300], rate=[2, 3, 4]) + b = a.slice(left=50, trim=True) + assert b == msprime.RateMap(position=[0, 50, 150, 250, 350], rate=[1, 2, 3, 4]) + + def test_slice_trim_right(self): + a = msprime.RateMap(position=[0, 100, 200, 300, 400], rate=[1, 2, 3, 4]) + b = a.slice(right=300, trim=True) + assert b == msprime.RateMap(position=[0, 100, 200, 300], rate=[1, 2, 3]) + b = a.slice(right=350, trim=True) + assert b == msprime.RateMap(position=[0, 100, 200, 300, 350], rate=[1, 2, 3, 4]) + + def test_slice_error(self): + recomb_map = msprime.RateMap(position=[0, 100], rate=[1]) + with pytest.raises(KeyError): + recomb_map.slice(left=-1) + with pytest.raises(KeyError): + recomb_map.slice(right=-1) + with pytest.raises(KeyError): + recomb_map.slice(left=200) + with pytest.raises(KeyError): + recomb_map.slice(right=200) + with pytest.raises(KeyError): + recomb_map.slice(left=20, right=10) + + +class TestReadHapmap: + def test_read_hapmap_simple(self): + hapfile = io.StringIO( + """\ + HEADER + chr1 1 x 0 + chr1 2 x 0.000001 x + chr1 3 x 0.000006 x x""" + ) + rm = msprime.RateMap.read_hapmap(hapfile) + assert_array_equal(rm.position, [0, 1, 2, 3]) + assert np.allclose(rm.rate, [np.nan, 1e-8, 5e-8], equal_nan=True) + + def test_read_hapmap_from_filename(self, tmp_path): + with open(tmp_path / "hapfile.txt", "w") as hapfile: + hapfile.write( + """\ + HEADER + chr1 1 x 0 + chr1 2 x 0.000001 x + chr1 3 x 0.000006 x x""" + ) + rm = msprime.RateMap.read_hapmap(tmp_path / "hapfile.txt") + assert_array_equal(rm.position, [0, 1, 2, 3]) + assert np.allclose(rm.rate, [np.nan, 1e-8, 5e-8], equal_nan=True) + + @pytest.mark.filterwarnings("ignore:loadtxt") + def test_read_hapmap_empty(self): + hapfile = io.StringIO( + """\ + HEADER""" + ) + with pytest.raises(ValueError, 
match="Empty"): + msprime.RateMap.read_hapmap(hapfile) + + def test_read_hapmap_col_pos(self): + hapfile = io.StringIO( + """\ + HEADER + 0 0 + 0.000001 1 x + 0.000006 2 x x""" + ) + rm = msprime.RateMap.read_hapmap(hapfile, position_col=1, map_col=0) + assert_array_equal(rm.position, [0, 1, 2]) + assert np.allclose(rm.rate, [1e-8, 5e-8]) + + def test_read_hapmap_map_and_rate(self): + hapfile = io.StringIO( + """\ + HEADER + chr1 0 0 0 + chr1 1 1 0.000001 x + chr1 2 2 0.000006 x x""" + ) + with pytest.raises(ValueError, match="both rate_col and map_col"): + msprime.RateMap.read_hapmap(hapfile, rate_col=2, map_col=3) + + def test_read_hapmap_duplicate_pos(self): + hapfile = io.StringIO( + """\ + HEADER + 0 0 + 0.000001 1 x + 0.000006 2 x x""" + ) + with pytest.raises(ValueError, match="same columns"): + msprime.RateMap.read_hapmap(hapfile, map_col=1) + + def test_read_hapmap_nonzero_rate_start(self): + hapfile = io.StringIO( + """\ + HEADER + chr1 1 5 x + chr1 2 0 x x x""" + ) + rm = msprime.RateMap.read_hapmap(hapfile, rate_col=2) + assert_array_equal(rm.position, [0, 1, 2]) + assert_array_equal(rm.rate, [np.nan, 5e-8]) + + def test_read_hapmap_nonzero_rate_end(self): + hapfile = io.StringIO( + """\ + HEADER + chr1 0 5 x + chr1 2 1 x x x""" + ) + with pytest.raises(ValueError, match="last entry.*must be zero"): + msprime.RateMap.read_hapmap(hapfile, rate_col=2) + + def test_read_hapmap_gzipped(self, tmp_path): + hapfile = os.path.join(tmp_path, "hapmap.txt.gz") + with gzip.GzipFile(hapfile, "wb") as gzfile: + gzfile.write(b"HEADER\n") + gzfile.write(b"chr1 0 1\n") + gzfile.write(b"chr1 1 5.5\n") + gzfile.write(b"chr1 2 0\n") + rm = msprime.RateMap.read_hapmap(hapfile, rate_col=2) + assert_array_equal(rm.position, [0, 1, 2]) + assert_array_equal(rm.rate, [1e-8, 5.5e-8]) + + def test_read_hapmap_nonzero_map_start(self): + hapfile = io.StringIO( + """\ + HEADER + chr1 1 x 0.000001 + chr1 2 x 0.000001 x + chr1 3 x 0.000006 x x x""" + ) + rm = 
msprime.RateMap.read_hapmap(hapfile) + assert_array_equal(rm.position, [0, 1, 2, 3]) + assert np.allclose(rm.rate, [1e-8, 0, 5e-8]) + + def test_read_hapmap_bad_nonzero_map_start(self): + hapfile = io.StringIO( + """\ + HEADER + chr1 0 x 0.0000005 + chr1 1 x 0.000001 x + chr1 2 x 0.000006 x x x""" + ) + with pytest.raises(ValueError, match="start.*must be zero"): + msprime.RateMap.read_hapmap(hapfile) + + def test_sequence_length(self): + hapfile = io.StringIO( + """\ + HEADER + chr1 0 x 0 + chr1 1 x 0.000001 x + chr1 2 x 0.000006 x x x""" + ) + # test identical seq len + rm = msprime.RateMap.read_hapmap(hapfile, sequence_length=2) + assert_array_equal(rm.position, [0, 1, 2]) + assert np.allclose(rm.rate, [1e-8, 5e-8]) + + hapfile.seek(0) + rm = msprime.RateMap.read_hapmap(hapfile, sequence_length=10) + assert_array_equal(rm.position, [0, 1, 2, 10]) + assert np.allclose(rm.rate, [1e-8, 5e-8, np.nan], equal_nan=True) + + def test_bad_sequence_length(self): + hapfile = io.StringIO( + """\ + HEADER + chr1 0 x 0 + chr1 1 x 0.000001 x + chr1 2 x 0.000006 x x x""" + ) + with pytest.raises(ValueError, match="sequence_length"): + msprime.RateMap.read_hapmap(hapfile, sequence_length=1.999) + + def test_no_header(self): + data = """\ + chr1 0 x 0 + chr1 1 x 0.000001 x + chr1 2 x 0.000006 x x x""" + hapfile_noheader = io.StringIO(data) + hapfile_header = io.StringIO("chr pos rate cM\n" + data) + with pytest.raises(ValueError): + msprime.RateMap.read_hapmap(hapfile_header, has_header=False) + rm1 = msprime.RateMap.read_hapmap(hapfile_header) + rm2 = msprime.RateMap.read_hapmap(hapfile_noheader, has_header=False) + assert_array_equal(rm1.rate, rm2.rate) + assert_array_equal(rm1.position, rm2.position) + + def test_hapmap_fragment(self): + hapfile = io.StringIO( + """\ + chr pos rate cM + 1 4283592 3.79115663174456 0 + 1 4361401 0.0664276817058413 0.294986106359414 + 1 7979763 10.9082897515584 0.535345505591925 + 1 8007051 0.0976780648822495 0.833010916332456 + 1 8762788 
0.0899929572085616 0.906829844052373 + 1 9477943 0.0864382908650907 0.971188757364862 + 1 9696341 4.76495005895746 0.990066707213216 + 1 9752154 0.0864316558730679 1.25601286485381 + 1 9881751 0.0 1.26721414815999""" + ) + rm1 = msprime.RateMap.read_hapmap(hapfile) + hapfile.seek(0) + rm2 = msprime.RateMap.read_hapmap(hapfile, rate_col=2) + assert np.allclose(rm1.position, rm2.position) + assert np.allclose(rm1.rate, rm2.rate, equal_nan=True) diff --git a/python/tskit/intervals.py b/python/tskit/intervals.py new file mode 100644 index 0000000000..fe1baf7d75 --- /dev/null +++ b/python/tskit/intervals.py @@ -0,0 +1,759 @@ +# +# Copyright (C) 2020-2021 University of Oxford +# +# This file is part of msprime. +# +# msprime is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# msprime is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with msprime. If not, see . +# +""" +Utilities for working with intervals and interval maps. +""" +from __future__ import annotations + +import collections.abc +import itertools +import numbers +import warnings + +import numpy as np +from msprime import core + + +class RateMap(collections.abc.Mapping): + """ + A class mapping a non-negative rate value to a set of non-overlapping intervals + along the genome. Intervals for which the rate is unknown (i.e., missing data) + are encoded by NaN values in the ``rate`` array. + + :param list position: A list of :math:`n+1` positions, starting at 0, and ending + in the sequence length over which the RateMap will apply. 
+ :param list rate: A list of :math:`n` positive rates that apply between each + position. Intervals with missing data are encoded by NaN values. + """ + + # The args are marked keyword only to give us some flexibility in how we + # create class this in the future. + def __init__( + self, + *, + position, + rate, + ): + # Making the arrays read-only guarantees rate and cumulative mass stay in sync + # We prevent the arrays themselves being overwritten by making self.position, + # etc properties. + + # TODO we always coerce the position type to float here, but we may not + # want to do this. int32 is a perfectly good choice a lot of the time. + self._position = np.array(position, dtype=float) + self._position.flags.writeable = False + self._rate = np.array(rate, dtype=float) + self._rate.flags.writeable = False + size = len(self._position) + if size < 2: + raise ValueError("Must have at least two positions") + if len(self._rate) != size - 1: + raise ValueError( + "Rate array must have one less entry than the position array" + ) + if self._position[0] != 0: + raise ValueError("First position must be zero") + + span = self.span + if np.any(span <= 0): + bad_pos = np.where(span <= 0)[0] + 1 + raise ValueError( + f"Position values not strictly increasing at indexes {bad_pos}" + ) + if np.any(self._rate < 0): + bad_rates = np.where(self._rate < 0)[0] + raise ValueError(f"Rate values negative at indexes {bad_rates}") + self._missing = np.isnan(self.rate) + self._num_missing_intervals = np.sum(self._missing) + if self._num_missing_intervals == len(self.rate): + raise ValueError("All intervals are missing data") + # We don't expose the cumulative mass array as a part of the array + # API is it's not quite as obvious how it lines up for each interval. + # It's really the sum of the mass up to but not including the current + # interval, which is a bit confusing. 
Probably best to just leave + # it as a function, so that people can sample at regular positions + # along the genome anyway, emphasising that it's a continuous function, + # not a step function like the other interval attributes. + self._cumulative_mass = np.insert(np.nancumsum(self.mass), 0, 0) + assert self._cumulative_mass[0] == 0 + self._cumulative_mass.flags.writeable = False + + @property + def left(self): + """ + The left position of each interval (inclusive). + """ + return self._position[:-1] + + @property + def right(self): + """ + The right position of each interval (exclusive). + """ + return self._position[1:] + + @property + def mid(self): + """ + Returns the midpoint of each interval. + """ + mid = self.left + self.span / 2 + mid.flags.writeable = False + return mid + + @property + def span(self): + """ + Returns the span (i.e., ``right - left``) of each of the intervals. + """ + span = self.right - self.left + span.flags.writeable = False + return span + + @property + def position(self): + """ + The breakpoint positions between intervals. This is equal to the + :attr:`~.RateMap.left` array with the :attr:`sequence_length` + appended. + """ + return self._position + + @property + def rate(self): + """ + The rate associated with each interval. Missing data is encoded + by NaN values. + """ + return self._rate + + @property + def mass(self): + r""" + The "mass" of each interval, defined as the :attr:`~.RateMap.rate` + :math:`\times` :attr:`~.RateMap.span`. This is NaN for intervals + containing missing data. + """ + return self._rate * self.span + + @property + def missing(self): + """ + A boolean array encoding whether each interval contains missing data. + Equivalent to ``np.isnan(rate_map.rate)`` + """ + return self._missing + + @property + def non_missing(self): + """ + A boolean array encoding whether each interval contains non-missing data. 
+ Equivalent to ``np.logical_not(np.isnan(rate_map.rate))`` + """ + return ~self._missing + + # + # Interval counts + # + + @property + def num_intervals(self) -> int: + """ + The total number of intervals in this map. Equal to + :attr:`~.RateMap.num_missing_intervals` + + :attr:`~.RateMap.num_non_missing_intervals`. + """ + return len(self._rate) + + @property + def num_missing_intervals(self) -> int: + """ + Returns the number of missing intervals, i.e., those in which the + :attr:`~.RateMap.rate` value is a NaN. + """ + return self._num_missing_intervals + + @property + def num_non_missing_intervals(self) -> int: + """ + The number of non missing intervals, i.e., those in which the + :attr:`~.RateMap.rate` value is not a NaN. + """ + return self.num_intervals - self.num_missing_intervals + + @property + def sequence_length(self): + """ + The sequence length covered by this map + """ + return self.position[-1] + + @property + def total_mass(self): + """ + The cumulative total mass over the entire map. + """ + return self._cumulative_mass[-1] + + @property + def mean_rate(self): + """ + The mean rate over this map weighted by the span covered by each rate. + Unknown intervals are excluded. + """ + total_span = np.sum(self.span[self.non_missing]) + return self.total_mass / total_span + + def get_rate(self, x): + """ + Return the rate at the specified list of positions. + + .. note:: This function will return a NaN value for any positions + that contain missing data. + + :param numpy.ndarray x: The positions for which to return values. + :return: An array of rates, the same length as ``x``. + :rtype: numpy.ndarray + """ + loc = np.searchsorted(self.position, x, side="right") - 1 + if np.any(loc < 0) or np.any(loc >= len(self.rate)): + raise ValueError("position out of bounds") + return self.rate[loc] + + def get_cumulative_mass(self, x): + """ + Return the cumulative mass of the map up to (but not including) a + given point for a list of positions along the map. 
This is equal to + the integral of the rate from 0 to the point. + + :param numpy.ndarray x: The positions for which to return values. + + :return: An array of cumulative mass values, the same length as ``x`` + :rtype: numpy.ndarray + """ + x = np.array(x) + if np.any(x < 0) or np.any(x > self.sequence_length): + raise ValueError(f"Cannot have positions < 0 or > {self.sequence_length}") + return np.interp(x, self.position, self._cumulative_mass) + + def find_index(self, x: float) -> int: + """ + Returns the index of the interval that the specified position falls within, + such that ``rate_map.left[index] <= x < self.rate_map.right[index]``. + + :param float x: The position to search. + :return: The index of the interval containing this point. + :rtype: int + :raises: KeyError if the position is not contained in any of the intervals. + """ + if x < 0 or x >= self.sequence_length: + raise KeyError(f"Position {x} out of bounds") + index = np.searchsorted(self.position, x, side="left") + if x < self.position[index]: + index -= 1 + assert self.left[index] <= x < self.right[index] + return index + + def missing_intervals(self): + """ + Returns the left and right coordinates of the intervals containing + missing data in this map as a 2D numpy array + with shape (:attr:`~.RateMap.num_missing_intervals`, 2). Each row + of this returned array is therefore a ``left``, ``right`` tuple + corresponding to the coordinates of the missing intervals. + + :return: A numpy array of the coordinates of intervals containing + missing data. + :rtype: numpy.ndarray + """ + out = np.empty((self.num_missing_intervals, 2)) + out[:, 0] = self.left[self.missing] + out[:, 1] = self.right[self.missing] + return out + + def asdict(self): + return {"position": self.position, "rate": self.rate} + + # + # Dunder methods. We implement the Mapping protocol via __iter__, __len__ + # and __getitem__. We have some extra semantics for __getitem__, providing + # slice notation. 
+ # + + def __iter__(self): + # The clinching argument for using mid here is that if we used + # left instead we would have + # RateMap([0, 1], [0.1]) == RateMap([0, 100], [0.1]) + # by the inherited definition of equality since the dictionary items + # would be equal. + # Similarly, we only return the midpoints of known intervals + # because NaN values are not equal, and we would need to do + # something to work around this. It seems reasonable that + # this high-level operation returns the *known* values only + # anyway. + yield from self.mid[self.non_missing] + + def __len__(self): + return np.sum(self.non_missing) + + def __getitem__(self, key): + if isinstance(key, slice): + if key.step is not None: + raise TypeError("Only interval slicing is supported") + return self.slice(key.start, key.stop) + if isinstance(key, numbers.Number): + index = self.find_index(key) + if np.isnan(self.rate[index]): + # To be consistent with the __iter__ definition above we + # don't consider these missing positions to be "in" the map. + raise KeyError(f"Position {key} is within a missing interval") + return self.rate[index] + # TODO we could implement numpy array indexing here and call + # to get_rate. Note we'd need to take care that we return a keyerror + # if the returned array contains any nans though. 
+ raise KeyError("Key {key} not in map") + + def _display_table(self): + def format_row(left, right, mid, span, rate): + return [ + f"{left:.10g}", + f"{right:.10g}", + f"{mid:.10g}", + f"{span:.10g}", + f"{rate:.2g}", + ] + + def format_slice(start, end): + return list( + itertools.starmap( + format_row, + zip( + self.left[start:end], + self.right[start:end], + self.mid[start:end], + self.span[start:end], + self.rate[start:end], + ), + ) + ) + + if self.num_intervals < 40: + data = format_slice(0, None) + else: + data = format_slice(0, 10) + data.append(["⋯"] * 5) + data += format_slice(-10, None) + + return ["left", "right", "mid", "span", "rate"], data + + def __str__(self): + titles, data = self._display_table() + data = [[[item] for item in row] for row in data] + table = core.text_table( + caption="", + column_titles=[[title] for title in titles], + column_alignments="<<>>>", + data=data, + ) + return table + + def _repr_html_(self): + col_titles, data = self._display_table() + return core.html_table("", col_titles, data) + + def __repr__(self): + return f"RateMap(position={repr(self.position)}, rate={repr(self.rate)})" + + # + # Methods for building rate maps. + # + + def copy(self) -> RateMap: + """ + Returns a deep copy of this RateMap. + """ + # We take read-only copies of the arrays in the constructor anyway, so + # no need for copying. + return RateMap(position=self.position, rate=self.rate) + + def slice(self, left=None, right=None, *, trim=False) -> RateMap: # noqa: A003 + """ + Returns a subset of this rate map in the specified interval. + + :param float left: The left coordinate (inclusive) of the region to keep. + If ``None``, defaults to 0. + :param float right: The right coordinate (exclusive) of the region to keep. + If ``None``, defaults to the sequence length. + :param bool trim: If True, remove the flanking regions such that the + sequence length of the new rate map is ``right`` - ``left``. 
If ``False`` + (default), do not change the coordinate system and mark the flanking + regions as "unknown". + :return: A new RateMap instance + :rtype: RateMap + """ + left = 0 if left is None else left + right = self.sequence_length if right is None else right + if not (0 <= left < right <= self.sequence_length): + raise KeyError(f"Invalid slice: left={left}, right={right}") + + i = self.find_index(left) + j = i + np.searchsorted(self.position[i:], right, side="right") + if right > self.position[j - 1]: + j += 1 + + position = self.position[i:j].copy() + rate = self.rate[i : j - 1].copy() + position[0] = left + position[-1] = right + + if trim: + # Return trimmed map with changed coords + return RateMap(position=position - left, rate=rate) + + # Need to check regions before & after sliced region are filled out: + if left != 0: + if np.isnan(rate[0]): + position[0] = 0 # Extend + else: + rate = np.insert(rate, 0, np.nan) # Prepend + position = np.insert(position, 0, 0) + if right != self.position[-1]: + if np.isnan(rate[-1]): + position[-1] = self.sequence_length # Extend + else: + rate = np.append(rate, np.nan) # Append + position = np.append(position, self.position[-1]) + return RateMap(position=position, rate=rate) + + @staticmethod + def uniform(sequence_length, rate) -> RateMap: + """ + Create a uniform rate map + """ + return RateMap(position=[0, sequence_length], rate=[rate]) + + @staticmethod + def read_hapmap( + fileobj, + sequence_length=None, + *, + has_header=True, + position_col=None, + rate_col=None, + map_col=None, + ): + # Black barfs with an INTERNAL_ERROR trying to reformat this docstring, + # so we explicitly disable reformatting here. + # fmt: off + """ + Parses the specified file in HapMap format and returns a :class:`.RateMap`. + HapMap files must white-space-delimited, and by default are assumed to + contain a single header line (which is ignored). 
Each subsequent line
+ then contains a physical position (in base pairs) and either a genetic
+ map position (in centiMorgans) or a recombination rate (in centiMorgans
+ per megabase). The value in the rate column in a given line gives the
+ constant rate between the physical position in that line (inclusive) and the
+ physical position on the next line (exclusive).
+ By default, the second column of the file is taken
+ as the physical position and the fourth column is taken as the genetic
+ position, as seen in the following sample of the format::
+
+ Chromosome	Position(bp)	Rate(cM/Mb)	Map(cM)
+ chr10	48232	0.1614	0.002664
+ chr10	48486	0.1589	0.002705
+ chr10	50009	0.159	0.002947
+ chr10	52147	0.1574	0.003287
+ ...
+ chr10	133762002	3.358	181.129345
+ chr10	133766368	0.000	181.144008
+
+ In the example above, the first row has a nonzero genetic map position
+ (last column, cM), implying a nonzero recombination rate before that
+ position, that is assumed to extend to the start of the chromosome
+ (at position 0 bp). However, if the first line has a nonzero bp position
+ (second column) and a zero genetic map position (last column, cM),
+ then the recombination rate before that position is *unknown*, producing
+ :ref:`missing data `.
+
+ .. note::
+ The rows are all assumed to come from the same contig, and the
+ first column is currently ignored. Therefore if you have a single
+ file containing several contigs or chromosomes, you must split
+ it up into multiple files, and pass each one separately to this
+ function.
+
+ :param str fileobj: Filename or file to read. This is passed directly
+ to :func:`numpy.loadtxt`, so if the filename extension is .gz or .bz2,
+ the file is decompressed first
+ :param float sequence_length: The total length of the map. If ``None``,
+ then assume it is the last physical position listed in the file. 
+ Otherwise it must be greater than or equal to the last physical
+ position in the file, and the region between the last physical position
+ and the sequence_length is padded with a rate of zero.
+ :param bool has_header: If True (default), assume the file has a header row
+ and ignore the first line of the file.
+ :param int position_col: The zero-based index of the column in the file
+ specifying the physical position in base pairs. If ``None`` (default)
+ assume an index of 1 (i.e. the second column).
+ :param int rate_col: The zero-based index of the column in the file
+ specifying the rate in cM/Mb. If ``None`` (default) do not use the rate
+ column, but calculate rates using the genetic map positions, as
+ specified in ``map_col``. If the rate column is used, the
+ interval from 0 to first physical position in the file is marked as
+ unknown, and the last value in the rate column must be zero.
+ :param int map_col: The zero-based index of the column in the file
+ specifying the genetic map position in centiMorgans. If ``None``
+ (default), assume an index of 3 (i.e. the fourth column). If the first
+ genetic position is 0 the interval from position 0 to the first
+ physical position in the file is marked as unknown. Otherwise, act
+ as if an additional row, specifying physical position 0 and genetic
+ position 0, exists at the start of the file.
+ :return: A RateMap object. 
+ :rtype: RateMap + """ + # fmt: on + column_defs = {} # column definitions passed to np.loadtxt + if rate_col is None and map_col is None: + # Default to map_col + map_col = 3 + elif rate_col is not None and map_col is not None: + raise ValueError("Cannot specify both rate_col and map_col") + if map_col is not None: + column_defs[map_col] = ("map", float) + else: + column_defs[rate_col] = ("rate", float) + position_col = 1 if position_col is None else position_col + if position_col in column_defs: + raise ValueError( + "Cannot specify the same columns for position_col and " + "rate_col or map_col" + ) + column_defs[position_col] = ("pos", int) + + column_names = [c[0] for c in column_defs.values()] + column_data = np.loadtxt( + fileobj, + skiprows=1 if has_header else 0, + dtype=list(column_defs.values()), + usecols=list(column_defs.keys()), + unpack=True, + ) + data = dict(zip(column_names, column_data)) + + if "map" not in data: + assert "rate" in data + if data["rate"][-1] != 0: + raise ValueError("The last entry in the 'rate' column must be zero") + pos_Mb = data["pos"] / 1e6 + map_pos = np.cumsum(data["rate"][:-1] * np.diff(pos_Mb)) + data["map"] = np.insert(map_pos, 0, 0) / 100 + else: + data["map"] /= 100 # Convert centiMorgans to Morgans + if len(data["map"]) == 0: + raise ValueError("Empty hapmap file") + + # TO DO: read in chrom name from col 0 and poss set as .name + # attribute on the RateMap + + physical_positions = data["pos"] + genetic_positions = data["map"] + start = physical_positions[0] + end = physical_positions[-1] + + if genetic_positions[0] > 0 and start == 0: + raise ValueError( + "The map distance at the start of the chromosome must be zero" + ) + if start > 0: + physical_positions = np.insert(physical_positions, 0, 0) + if genetic_positions[0] > 0: + # Exception for a map that starts > 0cM: include the start rate + # in the mean + start = 0 + genetic_positions = np.insert(genetic_positions, 0, 0) + + if sequence_length is not None: + if 
sequence_length < end: + raise ValueError( + "The sequence_length cannot be less than the last physical position " + f" ({physical_positions[-1]})" + ) + if sequence_length > end: + physical_positions = np.append(physical_positions, sequence_length) + genetic_positions = np.append(genetic_positions, genetic_positions[-1]) + + assert genetic_positions[0] == 0 + rate = np.diff(genetic_positions) / np.diff(physical_positions) + if start != 0: + rate[0] = np.nan + if end != physical_positions[-1]: + rate[-1] = np.nan + return RateMap(position=physical_positions, rate=rate) + + +class RecombinationMap: + """ + A RecombinationMap represents the changing rates of recombination + along a chromosome. This is defined via two lists of numbers: + ``positions`` and ``rates``, which must be of the same length. + Given an index j in these lists, the rate of recombination + per base per generation is ``rates[j]`` over the interval + ``positions[j]`` to ``positions[j + 1]``. Consequently, the first + position must be zero, and by convention the last rate value + is also required to be zero (although it is not used). + + .. important:: + This class is deprecated (but supported indefinitely); + please use the :class:`.RateMap` class in new code. + In particular, note that when specifying ``rates`` in the + the :class:`.RateMap` class we now require an array + of length :math:`n - 1` (this class requires an array + of length :math:`n` in which the last entry is zero). + + :param list positions: The positions (in bases) denoting the + distinct intervals where recombination rates change. These can + be floating point values. + :param list rates: The list of rates corresponding to the supplied + ``positions``. Recombination rates are specified per base, + per generation. + :param int num_loci: **This parameter is no longer supported.** + Must be either None (meaning a continuous genome of the + finest possible resolution) or be equal to ``positions[-1]`` + (meaning a discrete genome). 
Any other value will result in + an error. Please see the :ref:`sec_legacy_0x_genome_discretisation` + section for more information. + """ + + def __init__(self, positions, rates, num_loci=None, map_start=0): + # Used as an internal flag for the 0.x simulate() function. This allows + # us to emulate the discrete-sites behaviour of 0.x code. + self._is_discrete = num_loci == positions[-1] + if num_loci is not None and num_loci != positions[-1]: + raise ValueError( + "The RecombinationMap interface is deprecated and only " + "partially supported. If you wish to simulate a number of " + "discrete loci, you must set num_loci == the sequence length. " + "If you wish to simulate recombination process on as fine " + "a map as possible, please omit the num_loci parameter (or set " + "to None). Otherwise, num_loci is no longer supported and " + "the behaviour of msprime 0.x cannot be emulated. Please " + "consider upgrading your code to the version 1.x APIs." + ) + self.map = RateMap(position=positions, rate=rates[:-1]) + + @classmethod + def uniform_map(cls, length, rate, num_loci=None): + """ + Returns a :class:`.RecombinationMap` instance in which the recombination + rate is constant over a chromosome of the specified length. + The legacy ``num_loci`` option is no longer supported and should not be used. + + :param float length: The length of the chromosome. + :param float rate: The rate of recombination per unit of sequence length + along this chromosome. + :param int num_loci: This parameter is no longer supported. + """ + return cls([0, length], [rate, 0], num_loci=num_loci) + + @classmethod + def read_hapmap(cls, filename): + """ + Parses the specified file in HapMap format. + + .. warning:: + This method is deprecated, use the :meth:`.RateMap.read_hapmap` + method instead. + + :param str filename: The name of the file to be parsed. This may be + in plain text or gzipped plain text. + :return: A RecombinationMap object. 
+ """ + warnings.warn( + "RecombinationMap.read_hapmap() is deprecated. " + "Use RateMap.read_hapmap() instead.", + FutureWarning, + ) + rate_map = RateMap.read_hapmap(filename, position_col=1, rate_col=2) + # Mark anything missing as 0 for backwards compatibility. This will + # ensure that simulate() never trims parts of the tree sequence. + rate = rate_map.rate.copy() + rate[rate_map.missing] = 0 + return cls(rate_map.position, np.append(rate, 0)) + + @property + def mean_recombination_rate(self): + """ + Return the weighted mean recombination rate + across all windows of the entire recombination map. + """ + return self.map.mean_rate + + def get_total_recombination_rate(self): + """ + Returns the effective recombination rate for this genetic map. + This is the weighted mean of the rates across all intervals. + """ + return self.map.total_mass + + def physical_to_genetic(self, x): + return self.map.get_cumulative_mass(x) + + def genetic_to_physical(self, genetic_x): + if self.map.total_mass == 0: + # If we have a zero recombination rate throughout then everything + # except L maps to 0. + return self.get_sequence_length() if genetic_x > 0 else 0 + if genetic_x == 0: + return self.map.position[0] + # TODO refactor this to this to use get_cumulative_mass() function / add the + # corresponding high-level function to the rate map. 
+ index = np.searchsorted(self.map._cumulative_mass, genetic_x) - 1 + y = ( + self.map.position[index] + + (genetic_x - self.map._cumulative_mass[index]) / self.map.rate[index] + ) + return y + + def physical_to_discrete_genetic(self, physical_x): + raise ValueError("Discrete genetic space is no longer supported") + + def get_per_locus_recombination_rate(self): + raise ValueError("Genetic loci are no longer supported") + + def get_num_loci(self): + raise ValueError("num_loci is no longer supported") + + def get_size(self): + return len(self.map.position) + + def get_positions(self): + return list(self.map.position) + + def get_rates(self): + return list(self.map.rate) + [0] + + def get_sequence_length(self): + return self.map.sequence_length + + def get_length(self): + # Deprecated: use get_sequence_length() instead + return self.get_sequence_length() + + def asdict(self): + return self.map.asdict() From 4ebde51605cd06f37f4f5cce3e965cde35cecbc0 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Tue, 10 Jan 2023 14:30:08 +0000 Subject: [PATCH 16/84] Update copyright headers and licence text --- python/tests/test_intervals.py | 30 ++++++++++++++++++------------ python/tskit/intervals.py | 29 +++++++++++++++++------------ 2 files changed, 35 insertions(+), 24 deletions(-) diff --git a/python/tests/test_intervals.py b/python/tests/test_intervals.py index 8d39d1838a..a4ac6cf984 100644 --- a/python/tests/test_intervals.py +++ b/python/tests/test_intervals.py @@ -1,19 +1,25 @@ +# MIT License # +# Copyright (c) 2023 Tskit Developers +# Copyright (C) 2020-2021 University of Oxford # -# This file is part of msprime. 
+# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: # -# msprime is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. # -# msprime is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with msprime. If not, see . +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. # """ Test cases for the intervals module. 
diff --git a/python/tskit/intervals.py b/python/tskit/intervals.py index fe1baf7d75..60c630c5c2 100644 --- a/python/tskit/intervals.py +++ b/python/tskit/intervals.py @@ -1,20 +1,25 @@ +# MIT License # +# Copyright (c) 2023 Tskit Developers # Copyright (C) 2020-2021 University of Oxford # -# This file is part of msprime. +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: # -# msprime is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. # -# msprime is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with msprime. If not, see . +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
# """ Utilities for working with intervals and interval maps. From 2e21a477d256497a53e2694a5aa1b6d30d573aee Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Thu, 10 Nov 2022 12:29:36 +0000 Subject: [PATCH 17/84] Allow alignments to be specified To match the msprime `text_table` function --- python/tests/test_util.py | 21 +++++++++++++++++++++ python/tskit/util.py | 28 ++++++++++++++++++++-------- 2 files changed, 41 insertions(+), 8 deletions(-) diff --git a/python/tests/test_util.py b/python/tests/test_util.py index cc4f9d45da..b4f8d314a9 100644 --- a/python/tests/test_util.py +++ b/python/tests/test_util.py @@ -489,6 +489,27 @@ def test_unicode_table(): ) +def test_unicode_table_alignments(): + assert ( + util.unicode_table( + [["5", "6", "7", "8"], ["90", "10", "11", "12"]], + header=["1", "2", "3", "4"], + alignments="<>><", + ) + == textwrap.dedent( + """ + ╔══╤══╤══╤══╗ + ║1 │2 │3 │4 ║ + ╠══╪══╪══╪══╣ + ║5 │ 6│ 7│8 ║ + ╟──┼──┼──┼──╢ + ║90│10│11│12║ + ╚══╧══╧══╧══╝ + """ + )[1:] + ) + + def test_set_printoptions(): assert tskit._print_options == {"max_lines": 40} util.set_print_options(max_lines=None) diff --git a/python/tskit/util.py b/python/tskit/util.py index 9baa298ceb..576b2be8bb 100644 --- a/python/tskit/util.py +++ b/python/tskit/util.py @@ -320,7 +320,7 @@ def obj_to_collapsed_html(d, name=None, open_depth=0): :param str name: Name for this object :param int open_depth: By default sub-sections are collapsed. If this number is - non-zero the first layers up to open_depth will be opened. + non-zero the first layers up to open_depth will be opened. :return: The HTML as a string :rtype: str """ @@ -369,19 +369,25 @@ def render_metadata(md, length=40): return truncate_string_end(str(md), length) -def unicode_table(rows, title=None, header=None, row_separator=True): +def unicode_table( + rows, title=None, header=None, row_separator=True, column_alignments=None +): """ Convert a table (list of lists) of strings to a unicode table. 
If a row contains the string "__skipped__NNN" then "skipped N rows" is displayed. :param list[list[str]] rows: List of rows, each of which is a list of strings for - each cell. The first column will be left justified, the others right. Each row must - have the same number of cells. + each cell. Each row must have the same number of cells. :param str title: If specified the first output row will be a single cell - containing this string, left-justified. [optional] + containing this string, left-justified. [optional] :param list[str] header: Specifies a row above the main rows which will be in double - lined borders and left justified. Must be same length as each row. [optional] + lined borders and left justified. Must be same length as each row. [optional] :param boolean row_separator: If True add lines between each row. [Default: True] + :param column_alignments str: A string of the same length as the number of cells in + a row (i.e. columns) where each character specifies an alignment such as ``<``, + ``>`` or ``^`` as used in Python's string formatting mini-language. 
If ``None``, + set the first column to be left justified and the remaining columns to be right + justified [Default: ``None``] :return: The table as a string :rtype: str """ @@ -392,6 +398,8 @@ def unicode_table(rows, title=None, header=None, row_separator=True): widths = [ max(len(row[i_col]) for row in all_rows) for i_col in range(len(all_rows[0])) ] + if column_alignments is None: + column_alignments = "<" + ">" * (len(widths) - 1) out = [] inner_width = sum(widths) + len(header or rows[0]) - 1 if title is not None: @@ -423,9 +431,13 @@ def unicode_table(rows, title=None, header=None, row_separator=True): else: if i != 0 and not last_skipped and row_separator: out.append(f"╟{'┼'.join('─' * w for w in widths)}╢\n") + out.append( - f"║{row[0].ljust(widths[0])}│" - f"{'│'.join(cell.rjust(w) for cell, w in zip(row[1:], widths[1:]))}║\n" + "║" + + "│".join( + f"{r:{a}{w}}" for r, w, a in zip(row, widths, column_alignments) + ) + + "║\n" ) last_skipped = False From 44d07a6a78f4ae6a8a717828560c86b380a8cb9f Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Thu, 10 Nov 2022 13:50:51 +0000 Subject: [PATCH 18/84] Place html table formatting in util.py Allows it to be used by other tabular outputting routines. 
Also put the "limit" functionality into a single util function --- python/tests/test_util.py | 4 +- python/tskit/tables.py | 122 ++++++-------------------------------- python/tskit/util.py | 52 +++++++++++++++- 3 files changed, 70 insertions(+), 108 deletions(-) diff --git a/python/tests/test_util.py b/python/tests/test_util.py index b4f8d314a9..1d78dd0a08 100644 --- a/python/tests/test_util.py +++ b/python/tests/test_util.py @@ -489,12 +489,12 @@ def test_unicode_table(): ) -def test_unicode_table_alignments(): +def test_unicode_table_column_alignments(): assert ( util.unicode_table( [["5", "6", "7", "8"], ["90", "10", "11", "12"]], header=["1", "2", "3", "4"], - alignments="<>><", + column_alignments="<>><", ) == textwrap.dedent( """ diff --git a/python/tskit/tables.py b/python/tskit/tables.py index b875115621..2de36496f1 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -27,7 +27,6 @@ import collections.abc import dataclasses import datetime -import itertools import json import numbers import warnings @@ -657,41 +656,10 @@ def __str__(self): return util.unicode_table(rows, header=headers, row_separator=False) def _repr_html_(self): - """ - Called by jupyter notebooks to render tables - """ headers, rows = self._text_header_and_rows( limit=tskit._print_options["max_lines"] ) - headers = "".join(f"{header}" for header in headers) - rows = ( - f'{row[11:]}' - f" rows skipped (tskit.set_print_options)" - if "__skipped__" in row - else "".join(f"{cell}" for cell in row) - for row in rows - ) - rows = "".join(f"{row}\n" for row in rows) - return f""" -
- - - - - {headers} - - - - {rows} - -
-
- """ + return util.html_table(rows, header=headers) def _columns_all_integer(self, *colnames): # For displaying floating point values without loads of decimal places @@ -852,15 +820,8 @@ def __init__(self, max_rows_increment=0, ll_table=None): def _text_header_and_rows(self, limit=None): headers = ("id", "flags", "location", "parents", "metadata") rows = [] - if limit is None or self.num_rows <= limit: - indexes = range(self.num_rows) - else: - indexes = itertools.chain( - range(limit // 2), - [-1], - range(self.num_rows - (limit - (limit // 2)), self.num_rows), - ) - for j in indexes: + row_indexes = util.truncate_rows(self.num_rows, limit) + for j in row_indexes: if j == -1: rows.append(f"__skipped__{self.num_rows-limit}") else: @@ -1105,16 +1066,9 @@ def __init__(self, max_rows_increment=0, ll_table=None): def _text_header_and_rows(self, limit=None): headers = ("id", "flags", "population", "individual", "time", "metadata") rows = [] - if limit is None or self.num_rows <= limit: - indexes = range(self.num_rows) - else: - indexes = itertools.chain( - range(limit // 2), - [-1], - range(self.num_rows - (limit - (limit // 2)), self.num_rows), - ) + row_indexes = util.truncate_rows(self.num_rows, limit) decimal_places_times = 0 if self._columns_all_integer("time") else 8 - for j in indexes: + for j in row_indexes: row = self[j] if j == -1: rows.append(f"__skipped__{self.num_rows-limit}") @@ -1306,16 +1260,9 @@ def __init__(self, max_rows_increment=0, ll_table=None): def _text_header_and_rows(self, limit=None): headers = ("id", "left", "right", "parent", "child", "metadata") rows = [] - if limit is None or self.num_rows <= limit: - indexes = range(self.num_rows) - else: - indexes = itertools.chain( - range(limit // 2), - [-1], - range(self.num_rows - (limit - (limit // 2)), self.num_rows), - ) + row_indexes = util.truncate_rows(self.num_rows, limit) decimal_places = 0 if self._columns_all_integer("left", "right") else 8 - for j in indexes: + for j in row_indexes: if j 
== -1: rows.append(f"__skipped__{self.num_rows-limit}") else: @@ -1528,17 +1475,10 @@ def __init__(self, max_rows_increment=0, ll_table=None): def _text_header_and_rows(self, limit=None): headers = ("id", "left", "right", "node", "source", "dest", "time", "metadata") rows = [] - if limit is None or self.num_rows <= limit: - indexes = range(self.num_rows) - else: - indexes = itertools.chain( - range(limit // 2), - [-1], - range(self.num_rows - (limit - (limit // 2)), self.num_rows), - ) + row_indexes = util.truncate_rows(self.num_rows, limit) decimal_places_coords = 0 if self._columns_all_integer("left", "right") else 8 decimal_places_times = 0 if self._columns_all_integer("time") else 8 - for j in indexes: + for j in row_indexes: if j == -1: rows.append(f"__skipped__{self.num_rows-limit}") else: @@ -1748,16 +1688,9 @@ def __init__(self, max_rows_increment=0, ll_table=None): def _text_header_and_rows(self, limit=None): headers = ("id", "position", "ancestral_state", "metadata") rows = [] - if limit is None or self.num_rows <= limit: - indexes = range(self.num_rows) - else: - indexes = itertools.chain( - range(limit // 2), - [-1], - range(self.num_rows - (limit - (limit // 2)), self.num_rows), - ) + row_indexes = util.truncate_rows(self.num_rows, limit) decimal_places = 0 if self._columns_all_integer("position") else 8 - for j in indexes: + for j in row_indexes: if j == -1: rows.append(f"__skipped__{self.num_rows-limit}") else: @@ -1971,17 +1904,10 @@ def __init__(self, max_rows_increment=0, ll_table=None): def _text_header_and_rows(self, limit=None): headers = ("id", "site", "node", "time", "derived_state", "parent", "metadata") rows = [] - if limit is None or self.num_rows <= limit: - indexes = range(self.num_rows) - else: - indexes = itertools.chain( - range(limit // 2), - [-1], - range(self.num_rows - (limit - (limit // 2)), self.num_rows), - ) + row_indexes = util.truncate_rows(self.num_rows, limit) # Currently mutations do not have discretised times: this for 
consistency decimal_places_times = 0 if self._columns_all_integer("time") else 8 - for j in indexes: + for j in row_indexes: if j == -1: rows.append(f"__skipped__{self.num_rows-limit}") else: @@ -2232,15 +2158,8 @@ def add_row(self, metadata=None): def _text_header_and_rows(self, limit=None): headers = ("id", "metadata") rows = [] - if limit is None or self.num_rows <= limit: - indexes = range(self.num_rows) - else: - indexes = itertools.chain( - range(limit // 2), - [-1], - range(self.num_rows - (limit - (limit // 2)), self.num_rows), - ) - for j in indexes: + row_indexes = util.truncate_rows(self.num_rows, limit) + for j in row_indexes: if j == -1: rows.append(f"__skipped__{self.num_rows-limit}") else: @@ -2490,15 +2409,8 @@ def append_columns( def _text_header_and_rows(self, limit=None): headers = ("id", "timestamp", "record") rows = [] - if limit is None or self.num_rows <= limit: - indexes = range(self.num_rows) - else: - indexes = itertools.chain( - range(limit // 2), - [-1], - range(self.num_rows - (limit - (limit // 2)), self.num_rows), - ) - for j in indexes: + row_indexes = util.truncate_rows(self.num_rows, limit) + for j in row_indexes: if j == -1: rows.append(f"__skipped__{self.num_rows-limit}") else: diff --git a/python/tskit/util.py b/python/tskit/util.py index 576b2be8bb..ca8ca77f99 100644 --- a/python/tskit/util.py +++ b/python/tskit/util.py @@ -23,6 +23,7 @@ Module responsible for various utility functions used in other modules. """ import dataclasses +import itertools import json import numbers import os @@ -370,7 +371,7 @@ def render_metadata(md, length=40): def unicode_table( - rows, title=None, header=None, row_separator=True, column_alignments=None + rows, *, title=None, header=None, row_separator=True, column_alignments=None ): """ Convert a table (list of lists) of strings to a unicode table. 
If a row contains @@ -445,6 +446,41 @@ def unicode_table( return "".join(out) +def html_table(rows, *, header): + """ + Called by jupyter notebooks to render tables + """ + headers = "".join(f"{h}" for h in header) + rows = ( + f'{row[11:]}' + f" rows skipped (tskit.set_print_options)" + if "__skipped__" in row + else "".join(f"{cell}" for cell in row) + for row in rows + ) + rows = "".join(f"{row}\n" for row in rows) + return f""" +
+ + + + + {headers} + + + + {rows} + +
+
+ """ + + def tree_sequence_html(ts): table_rows = "".join( f""" @@ -686,6 +722,20 @@ def set_print_options(*, max_lines=40): tskit._print_options = {"max_lines": max_lines} +def truncate_rows(num_rows, limit=None): + """ + Return a list of indexes into a set of rows, but is limit is set, truncate the + number of rows and place a `-1` instead of the intermediate indexes + """ + if limit is None or num_rows <= limit: + return range(num_rows) + return itertools.chain( + range(limit // 2), + [-1], + range(num_rows - (limit - (limit // 2)), num_rows), + ) + + def random_nucleotides(length: numbers.Number, *, seed: Union[int, None] = None) -> str: """ Returns a random string of nucleotides of the specified length. Characters From 805b1534734cca5d1fe8a2dacc68501d731e5116 Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Thu, 10 Nov 2022 13:52:11 +0000 Subject: [PATCH 19/84] Fixup intervals module Rework a little to use the built-in tskit unicode and html output functions --- python/CHANGELOG.rst | 10 +- python/tests/test_intervals.py | 241 +++++++++++++++++---------------- python/tskit/__init__.py | 3 +- python/tskit/intervals.py | 225 +++++------------------------- python/tskit/tables.py | 5 +- python/tskit/util.py | 9 +- 6 files changed, 171 insertions(+), 322 deletions(-) diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 1f33ba6d09..c593ba2ef9 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -25,6 +25,10 @@ or is an isolated non-sample node. 
(:user:`szhan`, :pr:`2618`, :issue:`2616`) +- The ``msprime.RateMap`` class has been ported into tskit: functionality should + be identical to the version in msprime, apart from minor changes in the formatting + of tabular text output (:user:`hyanwong`, :user:`jeromekelleher`, :pr:`2678`) + **Breaking Changes** - the ``filter_populations``, ``filter_individuals``, and ``filter_sites`` @@ -62,7 +66,7 @@ - Single statistics computed with ``TreeSequence.general_stat`` are now returned as numpy scalars if windows=None, AND; samples is a single - list or None (for a 1-way stat), OR indexes is None or a single list of + list or None (for a 1-way stat), OR indexes is None or a single list of length k (instead of a list of length-k lists). (:user:`gtsambos`, :pr:`2417`, :issue:`2308`) @@ -77,10 +81,10 @@ **Performance improvements** - TreeSequence.link_ancestors no longer continues to process edges once all - of the sample and ancestral nodes have been accounted for, improving memory + of the sample and ancestral nodes have been accounted for, improving memory overhead and overall performance (:user:`gtsambos`, :pr:`2456`, :issue:`2442`) - + -------------------- [0.5.2] - 2022-07-29 -------------------- diff --git a/python/tests/test_intervals.py b/python/tests/test_intervals.py index a4ac6cf984..f4ac31dfea 100644 --- a/python/tests/test_intervals.py +++ b/python/tests/test_intervals.py @@ -33,11 +33,12 @@ import textwrap import xml -import msprime import numpy as np import pytest from numpy.testing import assert_array_equal +import tskit + class TestRateMapErrors: @pytest.mark.parametrize( @@ -53,40 +54,40 @@ class TestRateMapErrors: ) def test_bad_input(self, position, rate): with pytest.raises(ValueError): - msprime.RateMap(position=position, rate=rate) + tskit.RateMap(position=position, rate=rate) def test_zero_length_interval(self): with pytest.raises(ValueError, match=r"at indexes \[2 4\]"): - msprime.RateMap(position=[0, 1, 1, 2, 2, 3], rate=[0, 0, 0, 0, 0]) + 
tskit.RateMap(position=[0, 1, 1, 2, 2, 3], rate=[0, 0, 0, 0, 0]) def test_bad_length(self): positions = np.array([0, 1, 2]) rates = np.array([0, 1, 2]) with pytest.raises(ValueError, match="one less entry"): - msprime.RateMap(position=positions, rate=rates) + tskit.RateMap(position=positions, rate=rates) def test_bad_first_pos(self): positions = np.array([1, 2, 3]) rates = np.array([1, 1]) with pytest.raises(ValueError, match="First position"): - msprime.RateMap(position=positions, rate=rates) + tskit.RateMap(position=positions, rate=rates) def test_bad_rate(self): positions = np.array([0, 1, 2]) rates = np.array([1, -1]) with pytest.raises(ValueError, match="negative.*1"): - msprime.RateMap(position=positions, rate=rates) + tskit.RateMap(position=positions, rate=rates) def test_bad_rate_with_missing(self): positions = np.array([0, 1, 2]) rates = np.array([np.nan, -1]) with pytest.raises(ValueError, match="negative.*1"): - msprime.RateMap(position=positions, rate=rates) + tskit.RateMap(position=positions, rate=rates) def test_read_only(self): positions = np.array([0, 0.25, 0.5, 0.75, 1]) rates = np.array([0.125, 0.25, 0.5, 0.75]) # 1 shorter than positions - rate_map = msprime.RateMap(position=positions, rate=rates) + rate_map = tskit.RateMap(position=positions, rate=rates) assert np.all(rates == rate_map.rate) assert np.all(positions == rate_map.position) with pytest.raises(AttributeError): @@ -113,12 +114,12 @@ def test_read_only(self): class TestGetRateAllKnown: examples = [ - msprime.RateMap(position=[0, 1], rate=[0]), - msprime.RateMap(position=[0, 1], rate=[0.1]), - msprime.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]), - msprime.RateMap(position=[0, 1, 2], rate=[0, 0.2]), - msprime.RateMap(position=[0, 1, 2], rate=[0.1, 1e-6]), - msprime.RateMap(position=range(100), rate=range(99)), + tskit.RateMap(position=[0, 1], rate=[0]), + tskit.RateMap(position=[0, 1], rate=[0.1]), + tskit.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]), + tskit.RateMap(position=[0, 1, 2], 
rate=[0, 0.2]), + tskit.RateMap(position=[0, 1, 2], rate=[0.1, 1e-6]), + tskit.RateMap(position=range(100), rate=range(99)), ] @pytest.mark.parametrize("rate_map", examples) @@ -145,16 +146,16 @@ def test_get_rate_right(self, rate_map): class TestOperations: examples = [ - msprime.RateMap(position=[0, 1], rate=[0]), - msprime.RateMap(position=[0, 1], rate=[0.1]), - msprime.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]), - msprime.RateMap(position=[0, 1, 2], rate=[0, 0.2]), - msprime.RateMap(position=[0, 1, 2], rate=[0.1, 1e-6]), - msprime.RateMap(position=range(100), rate=range(99)), + tskit.RateMap.uniform(sequence_length=1, rate=0), + tskit.RateMap.uniform(sequence_length=1, rate=0.1), + tskit.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]), + tskit.RateMap(position=[0, 1, 2], rate=[0, 0.2]), + tskit.RateMap(position=[0, 1, 2], rate=[0.1, 1e-6]), + tskit.RateMap(position=range(100), rate=range(99)), # Missing data - msprime.RateMap(position=[0, 1, 2], rate=[np.nan, 0]), - msprime.RateMap(position=[0, 1, 2], rate=[0, np.nan]), - msprime.RateMap(position=[0, 1, 2, 3], rate=[0, np.nan, 1]), + tskit.RateMap(position=[0, 1, 2], rate=[np.nan, 0]), + tskit.RateMap(position=[0, 1, 2], rate=[0, np.nan]), + tskit.RateMap(position=[0, 1, 2, 3], rate=[0, np.nan, 1]), ] @pytest.mark.parametrize("rate_map", examples) @@ -220,17 +221,23 @@ def test_map_semantics(self, rate_map): for x in rate_map.mid[rate_map.missing]: assert x not in rate_map + def test_asdict(self): + rate_map = tskit.RateMap.uniform(sequence_length=2, rate=4) + d = rate_map.asdict() + assert_array_equal(d["position"], np.array([0.0, 2.0])) + assert_array_equal(d["rate"], np.array([4.0])) + class TestFindIndex: def test_one_interval(self): - rate_map = msprime.RateMap(position=[0, 10], rate=[0.1]) + rate_map = tskit.RateMap(position=[0, 10], rate=[0.1]) for j in range(10): assert rate_map.find_index(j) == 0 assert rate_map.find_index(0.0001) == 0 assert rate_map.find_index(9.999) == 0 def test_two_intervals(self): 
- rate_map = msprime.RateMap(position=[0, 5, 10], rate=[0.1, 0.1]) + rate_map = tskit.RateMap(position=[0, 5, 10], rate=[0.1, 0.1]) assert rate_map.find_index(0) == 0 assert rate_map.find_index(0.0001) == 0 assert rate_map.find_index(4.9999) == 0 @@ -240,7 +247,7 @@ def test_two_intervals(self): assert rate_map.find_index(9.999) == 1 def test_three_intervals(self): - rate_map = msprime.RateMap(position=[0, 5, 10, 15], rate=[0.1, 0.1, 0.1]) + rate_map = tskit.RateMap(position=[0, 5, 10, 15], rate=[0.1, 0.1, 0.1]) assert rate_map.find_index(0) == 0 assert rate_map.find_index(0.0001) == 0 assert rate_map.find_index(4.9999) == 0 @@ -254,13 +261,13 @@ def test_three_intervals(self): assert rate_map.find_index(14.9999) == 2 def test_out_of_bounds(self): - rate_map = msprime.RateMap(position=[0, 10], rate=[0.1]) + rate_map = tskit.RateMap(position=[0, 10], rate=[0.1]) for bad_value in [-1, -0.0001, 10, 10.0001, 1e9]: with pytest.raises(KeyError, match="out of bounds"): rate_map.find_index(bad_value) def test_input_types(self): - rate_map = msprime.RateMap(position=[0, 10], rate=[0.1]) + rate_map = tskit.RateMap(position=[0, 10], rate=[0.1]) assert rate_map.find_index(0) == 0 assert rate_map.find_index(0.0) == 0 assert rate_map.find_index(np.zeros(1)[0]) == 0 @@ -269,37 +276,37 @@ def test_input_types(self): class TestSimpleExamples: def test_all_missing_one_interval(self): with pytest.raises(ValueError, match="missing data"): - msprime.RateMap(position=[0, 10], rate=[np.nan]) + tskit.RateMap(position=[0, 10], rate=[np.nan]) def test_all_missing_two_intervals(self): with pytest.raises(ValueError, match="missing data"): - msprime.RateMap(position=[0, 5, 10], rate=[np.nan, np.nan]) + tskit.RateMap(position=[0, 5, 10], rate=[np.nan, np.nan]) def test_count(self): - rate_map = msprime.RateMap(position=[0, 5, 10], rate=[np.nan, 1]) + rate_map = tskit.RateMap(position=[0, 5, 10], rate=[np.nan, 1]) assert rate_map.num_intervals == 2 assert rate_map.num_missing_intervals == 1 
assert rate_map.num_non_missing_intervals == 1 def test_missing_arrays(self): - rate_map = msprime.RateMap(position=[0, 5, 10], rate=[np.nan, 1]) + rate_map = tskit.RateMap(position=[0, 5, 10], rate=[np.nan, 1]) assert list(rate_map.missing) == [True, False] assert list(rate_map.non_missing) == [False, True] def test_missing_at_start_mean_rate(self): positions = np.array([0, 0.5, 1, 2]) rates = np.array([np.nan, 0, 1]) - rate_map = msprime.RateMap(position=positions, rate=rates) + rate_map = tskit.RateMap(position=positions, rate=rates) assert np.isclose(rate_map.mean_rate, 1 / (1 + 0.5)) def test_missing_at_end_mean_rate(self): positions = np.array([0, 1, 1.5, 2]) rates = np.array([1, 0, np.nan]) - rate_map = msprime.RateMap(position=positions, rate=rates) + rate_map = tskit.RateMap(position=positions, rate=rates) assert np.isclose(rate_map.mean_rate, 1 / (1 + 0.5)) def test_interval_properties_all_known(self): - rate_map = msprime.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + rate_map = tskit.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) assert list(rate_map.left) == [0, 1, 2] assert list(rate_map.right) == [1, 2, 3] assert list(rate_map.mid) == [0.5, 1.5, 2.5] @@ -307,55 +314,55 @@ def test_interval_properties_all_known(self): assert list(rate_map.mass) == [0.1, 0.2, 0.3] def test_pickle_non_missing(self): - r1 = msprime.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + r1 = tskit.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) r2 = pickle.loads(pickle.dumps(r1)) assert r1 == r2 def test_pickle_missing(self): - r1 = msprime.RateMap(position=[0, 1, 2, 3], rate=[0.1, np.nan, 0.3]) + r1 = tskit.RateMap(position=[0, 1, 2, 3], rate=[0.1, np.nan, 0.3]) r2 = pickle.loads(pickle.dumps(r1)) assert r1 == r2 def test_get_cumulative_mass_all_known(self): - rate_map = msprime.RateMap(position=[0, 10, 20, 30], rate=[0.1, 0.2, 0.3]) + rate_map = tskit.RateMap(position=[0, 10, 20, 30], rate=[0.1, 0.2, 0.3]) assert list(rate_map.mass) == [1, 2, 3] 
assert list(rate_map.get_cumulative_mass([10, 20, 30])) == [1, 3, 6] def test_cumulative_mass_missing(self): - rate_map = msprime.RateMap(position=[0, 10, 20, 30], rate=[0.1, np.nan, 0.3]) + rate_map = tskit.RateMap(position=[0, 10, 20, 30], rate=[0.1, np.nan, 0.3]) assert list(rate_map.get_cumulative_mass([10, 20, 30])) == [1, 1, 4] class TestDisplay: def test_str(self): - rate_map = msprime.RateMap(position=[0, 10], rate=[0.1]) - s = """ - ┌──────────────────────────────────┐ - │left │right │ mid│ span│ rate│ - ├──────────────────────────────────┤ - │0 │10 │ 5│ 10│ 0.1│ - └──────────────────────────────────┘ + rate_map = tskit.RateMap(position=[0, 10], rate=[0.1]) + s = """\ + ╔════╤═════╤═══╤════╤════╗ + ║left│right│mid│span│rate║ + ╠════╪═════╪═══╪════╪════╣ + ║0 │10 │ 5│ 10│ 0.1║ + ╚════╧═════╧═══╧════╧════╝ """ assert textwrap.dedent(s) == str(rate_map) def test_str_scinot(self): - rate_map = msprime.RateMap(position=[0, 10], rate=[0.000001]) - s = """ - ┌───────────────────────────────────┐ - │left │right │ mid│ span│ rate│ - ├───────────────────────────────────┤ - │0 │10 │ 5│ 10│ 1e-06│ - └───────────────────────────────────┘ + rate_map = tskit.RateMap(position=[0, 10], rate=[0.000001]) + s = """\ + ╔════╤═════╤═══╤════╤═════╗ + ║left│right│mid│span│rate ║ + ╠════╪═════╪═══╪════╪═════╣ + ║0 │10 │ 5│ 10│1e-06║ + ╚════╧═════╧═══╧════╧═════╝ """ assert textwrap.dedent(s) == str(rate_map) def test_repr(self): - rate_map = msprime.RateMap(position=[0, 10], rate=[0.1]) + rate_map = tskit.RateMap(position=[0, 10], rate=[0.1]) s = "RateMap(position=array([ 0., 10.]), rate=array([0.1]))" assert repr(rate_map) == s def test_repr_html(self): - rate_map = msprime.RateMap(position=[0, 10], rate=[0.1]) + rate_map = tskit.RateMap(position=[0, 10], rate=[0.1]) html = rate_map._repr_html_() root = xml.etree.ElementTree.fromstring(html) assert root.tag == "div" @@ -365,8 +372,8 @@ def test_repr_html(self): def test_long_table(self): n = 100 - rate_map = 
msprime.RateMap(position=range(n + 1), rate=[0.1] * n) - headers, data = rate_map._display_table() + rate_map = tskit.RateMap(position=range(n + 1), rate=[0.1] * n) + headers, data = rate_map._text_header_and_rows(limit=20) assert len(headers) == 5 assert len(data) == 21 # check some left values @@ -375,8 +382,8 @@ def test_long_table(self): def test_short_table(self): n = 10 - rate_map = msprime.RateMap(position=range(n + 1), rate=[0.1] * n) - headers, data = rate_map._display_table() + rate_map = tskit.RateMap(position=range(n + 1), rate=[0.1] * n) + headers, data = rate_map._text_header_and_rows(limit=20) assert len(headers) == 5 assert len(data) == n # check some left values. @@ -386,22 +393,22 @@ def test_short_table(self): class TestRateMapIsMapping: def test_items(self): - rate_map = msprime.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + rate_map = tskit.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) items = list(rate_map.items()) assert items[0] == (0.5, 0.1) assert items[1] == (1.5, 0.2) assert items[2] == (2.5, 0.3) def test_keys(self): - rate_map = msprime.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + rate_map = tskit.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) assert list(rate_map.keys()) == [0.5, 1.5, 2.5] def test_values(self): - rate_map = msprime.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + rate_map = tskit.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) assert list(rate_map.values()) == [0.1, 0.2, 0.3] def test_in_points(self): - rate_map = msprime.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + rate_map = tskit.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) # Any point within the map are True for x in [0, 0.5, 1, 2.9999]: assert x in rate_map @@ -410,7 +417,7 @@ def test_in_points(self): assert x not in rate_map def test_in_slices(self): - rate_map = msprime.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + rate_map = tskit.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) # 
slices that are within the map are "in" for x in [slice(0, 0.5), slice(0, 1), slice(0, 2), slice(2, 3), slice(0, 3)]: assert x in rate_map @@ -422,41 +429,41 @@ def test_in_slices(self): assert slice(-2, -1) not in rate_map def test_other_types_not_in(self): - rate_map = msprime.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + rate_map = tskit.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) for other_type in [None, "sdf", "123", {}, [], Exception]: assert other_type not in rate_map def test_len(self): - rate_map = msprime.RateMap(position=[0, 1], rate=[0.1]) + rate_map = tskit.RateMap(position=[0, 1], rate=[0.1]) assert len(rate_map) == 1 - rate_map = msprime.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]) + rate_map = tskit.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]) assert len(rate_map) == 2 - rate_map = msprime.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + rate_map = tskit.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) assert len(rate_map) == 3 def test_immutable(self): - rate_map = msprime.RateMap(position=[0, 1], rate=[0.1]) + rate_map = tskit.RateMap(position=[0, 1], rate=[0.1]) with pytest.raises(TypeError, match="item assignment"): rate_map[0] = 1 with pytest.raises(TypeError, match="item deletion"): del rate_map[0] def test_eq(self): - r1 = msprime.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]) - r2 = msprime.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]) + r1 = tskit.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]) + r2 = tskit.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]) assert r1 == r1 assert r1 == r2 - r2 = msprime.RateMap(position=[0, 1, 3], rate=[0.1, 0.2]) + r2 = tskit.RateMap(position=[0, 1, 3], rate=[0.1, 0.2]) assert r1 != r2 - assert msprime.RateMap(position=[0, 1], rate=[0.1]) != msprime.RateMap( + assert tskit.RateMap(position=[0, 1], rate=[0.1]) != tskit.RateMap( position=[0, 1], rate=[0.2] ) - assert msprime.RateMap(position=[0, 1], rate=[0.1]) != msprime.RateMap( + assert tskit.RateMap(position=[0, 1], rate=[0.1]) != 
tskit.RateMap( position=[0, 10], rate=[0.1] ) def test_getitem_value(self): - rate_map = msprime.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]) + rate_map = tskit.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]) assert rate_map[0] == 0.1 assert rate_map[0.5] == 0.1 assert rate_map[1] == 0.2 @@ -471,7 +478,7 @@ def test_getitem_value(self): assert rate_map[decimal.Decimal(1)] == 0.2 def test_getitem_slice(self): - r1 = msprime.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]) + r1 = tskit.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]) # The semantics of the slice() function are tested elsewhere. assert r1[:] == r1.copy() assert r1[:] is not r1 @@ -480,7 +487,7 @@ def test_getitem_slice(self): assert r1[0.5:1.5] == r1.slice(left=0.5, right=1.5) def test_getitem_slice_step(self): - r1 = msprime.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]) + r1 = tskit.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]) # Trying to set a "step" is a error with pytest.raises(TypeError, match="interval slicing"): r1[0:3:1] @@ -488,20 +495,20 @@ def test_getitem_slice_step(self): class TestMappingMissingData: def test_get_missing(self): - rate_map = msprime.RateMap(position=[0, 1, 2], rate=[np.nan, 0.2]) + rate_map = tskit.RateMap(position=[0, 1, 2], rate=[np.nan, 0.2]) with pytest.raises(KeyError, match="within a missing interval"): rate_map[0] with pytest.raises(KeyError, match="within a missing interval"): rate_map[0.999] def test_in_missing(self): - rate_map = msprime.RateMap(position=[0, 1, 2], rate=[np.nan, 0.2]) + rate_map = tskit.RateMap(position=[0, 1, 2], rate=[np.nan, 0.2]) assert 0 not in rate_map assert 0.999 not in rate_map assert 1 in rate_map def test_keys_missing(self): - rate_map = msprime.RateMap(position=[0, 1, 2], rate=[np.nan, 0.2]) + rate_map = tskit.RateMap(position=[0, 1, 2], rate=[np.nan, 0.2]) assert list(rate_map.keys()) == [1.5] @@ -509,13 +516,13 @@ class TestGetIntermediates: def test_get_rate(self): positions = np.array([0, 1, 2]) rates = np.array([1, 4]) - rate_map = 
msprime.RateMap(position=positions, rate=rates) + rate_map = tskit.RateMap(position=positions, rate=rates) assert np.all(rate_map.get_rate([0.5, 1.5]) == rates) def test_get_rate_out_of_bounds(self): positions = np.array([0, 1, 2]) rates = np.array([1, 4]) - rate_map = msprime.RateMap(position=positions, rate=rates) + rate_map = tskit.RateMap(position=positions, rate=rates) with pytest.raises(ValueError, match="out of bounds"): rate_map.get_rate([1, -0.1]) with pytest.raises(ValueError, match="out of bounds"): @@ -524,14 +531,14 @@ def test_get_rate_out_of_bounds(self): def test_get_cumulative_mass(self): positions = np.array([0, 1, 2]) rates = np.array([1, 4]) - rate_map = msprime.RateMap(position=positions, rate=rates) + rate_map = tskit.RateMap(position=positions, rate=rates) assert np.allclose(rate_map.get_cumulative_mass([0.5, 1.5]), np.array([0.5, 3])) assert rate_map.get_cumulative_mass([2]) == rate_map.total_mass def test_get_bad_cumulative_mass(self): positions = np.array([0, 1, 2]) rates = np.array([1, 4]) - rate_map = msprime.RateMap(position=positions, rate=rates) + rate_map = tskit.RateMap(position=positions, rate=rates) with pytest.raises(ValueError, match="positions"): rate_map.get_cumulative_mass([1, -0.1]) with pytest.raises(ValueError, match="positions"): @@ -541,7 +548,7 @@ def test_get_bad_cumulative_mass(self): class TestSlice: def test_slice_no_params(self): # test RateMap.slice(..., trim=False) - a = msprime.RateMap(position=[0, 100, 200, 300, 400], rate=[0, 1, 2, 3]) + a = tskit.RateMap(position=[0, 100, 200, 300, 400], rate=[0, 1, 2, 3]) b = a.slice() assert a.sequence_length == b.sequence_length assert_array_equal(a.position, b.position) @@ -549,7 +556,7 @@ def test_slice_no_params(self): assert a == b def test_slice_left_examples(self): - a = msprime.RateMap(position=[0, 100, 200, 300, 400], rate=[0, 1, 2, 3]) + a = tskit.RateMap(position=[0, 100, 200, 300, 400], rate=[0, 1, 2, 3]) b = a.slice(left=50) assert a.sequence_length == 
b.sequence_length assert_array_equal([0, 50, 100, 200, 300, 400], b.position) @@ -566,7 +573,7 @@ def test_slice_left_examples(self): assert_array_equal([np.nan, 1, 2, 3], b.rate) def test_slice_right_examples(self): - a = msprime.RateMap(position=[0, 100, 200, 300, 400], rate=[0, 1, 2, 3]) + a = tskit.RateMap(position=[0, 100, 200, 300, 400], rate=[0, 1, 2, 3]) b = a.slice(right=300) assert a.sequence_length == b.sequence_length assert_array_equal([0, 100, 200, 300, 400], b.position) @@ -578,7 +585,7 @@ def test_slice_right_examples(self): assert_array_equal([0, 1, 2, np.nan], b.rate) def test_slice_left_right_examples(self): - a = msprime.RateMap(position=[0, 100, 200, 300, 400], rate=[0, 1, 2, 3]) + a = tskit.RateMap(position=[0, 100, 200, 300, 400], rate=[0, 1, 2, 3]) b = a.slice(left=50, right=300) assert a.sequence_length == b.sequence_length assert_array_equal([0, 50, 100, 200, 300, 400], b.position) @@ -602,7 +609,7 @@ def test_slice_left_right_examples(self): def test_slice_right_missing(self): # If we take a right-slice into a trailing missing region, # we should recover the same map. 
- a = msprime.RateMap(position=[0, 100, 200, 300, 400], rate=[0, 1, 2, np.nan]) + a = tskit.RateMap(position=[0, 100, 200, 300, 400], rate=[0, 1, 2, np.nan]) b = a.slice(right=350) assert a.sequence_length == b.sequence_length assert_array_equal(a.position, b.position) @@ -614,7 +621,7 @@ def test_slice_right_missing(self): assert_array_equal(a.rate, b.rate) def test_slice_left_missing(self): - a = msprime.RateMap(position=[0, 100, 200, 300, 400], rate=[np.nan, 1, 2, 3]) + a = tskit.RateMap(position=[0, 100, 200, 300, 400], rate=[np.nan, 1, 2, 3]) b = a.slice(left=50) assert a.sequence_length == b.sequence_length assert_array_equal(a.position, b.position) @@ -627,7 +634,7 @@ def test_slice_left_missing(self): def test_slice_with_floats(self): # test RateMap.slice(..., trim=False) with floats - a = msprime.RateMap( + a = tskit.RateMap( position=[np.pi * x for x in [0, 100, 200, 300, 400]], rate=[0, 1, 2, 3] ) b = a.slice(left=50 * np.pi) @@ -636,21 +643,21 @@ def test_slice_with_floats(self): assert_array_equal([np.nan] + list(a.rate), b.rate) def test_slice_trim_left(self): - a = msprime.RateMap(position=[0, 100, 200, 300, 400], rate=[1, 2, 3, 4]) + a = tskit.RateMap(position=[0, 100, 200, 300, 400], rate=[1, 2, 3, 4]) b = a.slice(left=100, trim=True) - assert b == msprime.RateMap(position=[0, 100, 200, 300], rate=[2, 3, 4]) + assert b == tskit.RateMap(position=[0, 100, 200, 300], rate=[2, 3, 4]) b = a.slice(left=50, trim=True) - assert b == msprime.RateMap(position=[0, 50, 150, 250, 350], rate=[1, 2, 3, 4]) + assert b == tskit.RateMap(position=[0, 50, 150, 250, 350], rate=[1, 2, 3, 4]) def test_slice_trim_right(self): - a = msprime.RateMap(position=[0, 100, 200, 300, 400], rate=[1, 2, 3, 4]) + a = tskit.RateMap(position=[0, 100, 200, 300, 400], rate=[1, 2, 3, 4]) b = a.slice(right=300, trim=True) - assert b == msprime.RateMap(position=[0, 100, 200, 300], rate=[1, 2, 3]) + assert b == tskit.RateMap(position=[0, 100, 200, 300], rate=[1, 2, 3]) b = a.slice(right=350, 
trim=True) - assert b == msprime.RateMap(position=[0, 100, 200, 300, 350], rate=[1, 2, 3, 4]) + assert b == tskit.RateMap(position=[0, 100, 200, 300, 350], rate=[1, 2, 3, 4]) def test_slice_error(self): - recomb_map = msprime.RateMap(position=[0, 100], rate=[1]) + recomb_map = tskit.RateMap(position=[0, 100], rate=[1]) with pytest.raises(KeyError): recomb_map.slice(left=-1) with pytest.raises(KeyError): @@ -672,7 +679,7 @@ def test_read_hapmap_simple(self): chr1 2 x 0.000001 x chr1 3 x 0.000006 x x""" ) - rm = msprime.RateMap.read_hapmap(hapfile) + rm = tskit.RateMap.read_hapmap(hapfile) assert_array_equal(rm.position, [0, 1, 2, 3]) assert np.allclose(rm.rate, [np.nan, 1e-8, 5e-8], equal_nan=True) @@ -685,7 +692,7 @@ def test_read_hapmap_from_filename(self, tmp_path): chr1 2 x 0.000001 x chr1 3 x 0.000006 x x""" ) - rm = msprime.RateMap.read_hapmap(tmp_path / "hapfile.txt") + rm = tskit.RateMap.read_hapmap(tmp_path / "hapfile.txt") assert_array_equal(rm.position, [0, 1, 2, 3]) assert np.allclose(rm.rate, [np.nan, 1e-8, 5e-8], equal_nan=True) @@ -696,7 +703,7 @@ def test_read_hapmap_empty(self): HEADER""" ) with pytest.raises(ValueError, match="Empty"): - msprime.RateMap.read_hapmap(hapfile) + tskit.RateMap.read_hapmap(hapfile) def test_read_hapmap_col_pos(self): hapfile = io.StringIO( @@ -706,7 +713,7 @@ def test_read_hapmap_col_pos(self): 0.000001 1 x 0.000006 2 x x""" ) - rm = msprime.RateMap.read_hapmap(hapfile, position_col=1, map_col=0) + rm = tskit.RateMap.read_hapmap(hapfile, position_col=1, map_col=0) assert_array_equal(rm.position, [0, 1, 2]) assert np.allclose(rm.rate, [1e-8, 5e-8]) @@ -719,7 +726,7 @@ def test_read_hapmap_map_and_rate(self): chr1 2 2 0.000006 x x""" ) with pytest.raises(ValueError, match="both rate_col and map_col"): - msprime.RateMap.read_hapmap(hapfile, rate_col=2, map_col=3) + tskit.RateMap.read_hapmap(hapfile, rate_col=2, map_col=3) def test_read_hapmap_duplicate_pos(self): hapfile = io.StringIO( @@ -730,7 +737,7 @@ def 
test_read_hapmap_duplicate_pos(self): 0.000006 2 x x""" ) with pytest.raises(ValueError, match="same columns"): - msprime.RateMap.read_hapmap(hapfile, map_col=1) + tskit.RateMap.read_hapmap(hapfile, map_col=1) def test_read_hapmap_nonzero_rate_start(self): hapfile = io.StringIO( @@ -739,7 +746,7 @@ def test_read_hapmap_nonzero_rate_start(self): chr1 1 5 x chr1 2 0 x x x""" ) - rm = msprime.RateMap.read_hapmap(hapfile, rate_col=2) + rm = tskit.RateMap.read_hapmap(hapfile, rate_col=2) assert_array_equal(rm.position, [0, 1, 2]) assert_array_equal(rm.rate, [np.nan, 5e-8]) @@ -751,7 +758,7 @@ def test_read_hapmap_nonzero_rate_end(self): chr1 2 1 x x x""" ) with pytest.raises(ValueError, match="last entry.*must be zero"): - msprime.RateMap.read_hapmap(hapfile, rate_col=2) + tskit.RateMap.read_hapmap(hapfile, rate_col=2) def test_read_hapmap_gzipped(self, tmp_path): hapfile = os.path.join(tmp_path, "hapmap.txt.gz") @@ -760,7 +767,7 @@ def test_read_hapmap_gzipped(self, tmp_path): gzfile.write(b"chr1 0 1\n") gzfile.write(b"chr1 1 5.5\n") gzfile.write(b"chr1 2 0\n") - rm = msprime.RateMap.read_hapmap(hapfile, rate_col=2) + rm = tskit.RateMap.read_hapmap(hapfile, rate_col=2) assert_array_equal(rm.position, [0, 1, 2]) assert_array_equal(rm.rate, [1e-8, 5.5e-8]) @@ -772,7 +779,7 @@ def test_read_hapmap_nonzero_map_start(self): chr1 2 x 0.000001 x chr1 3 x 0.000006 x x x""" ) - rm = msprime.RateMap.read_hapmap(hapfile) + rm = tskit.RateMap.read_hapmap(hapfile) assert_array_equal(rm.position, [0, 1, 2, 3]) assert np.allclose(rm.rate, [1e-8, 0, 5e-8]) @@ -785,7 +792,7 @@ def test_read_hapmap_bad_nonzero_map_start(self): chr1 2 x 0.000006 x x x""" ) with pytest.raises(ValueError, match="start.*must be zero"): - msprime.RateMap.read_hapmap(hapfile) + tskit.RateMap.read_hapmap(hapfile) def test_sequence_length(self): hapfile = io.StringIO( @@ -796,12 +803,12 @@ def test_sequence_length(self): chr1 2 x 0.000006 x x x""" ) # test identical seq len - rm = 
msprime.RateMap.read_hapmap(hapfile, sequence_length=2) + rm = tskit.RateMap.read_hapmap(hapfile, sequence_length=2) assert_array_equal(rm.position, [0, 1, 2]) assert np.allclose(rm.rate, [1e-8, 5e-8]) hapfile.seek(0) - rm = msprime.RateMap.read_hapmap(hapfile, sequence_length=10) + rm = tskit.RateMap.read_hapmap(hapfile, sequence_length=10) assert_array_equal(rm.position, [0, 1, 2, 10]) assert np.allclose(rm.rate, [1e-8, 5e-8, np.nan], equal_nan=True) @@ -814,7 +821,7 @@ def test_bad_sequence_length(self): chr1 2 x 0.000006 x x x""" ) with pytest.raises(ValueError, match="sequence_length"): - msprime.RateMap.read_hapmap(hapfile, sequence_length=1.999) + tskit.RateMap.read_hapmap(hapfile, sequence_length=1.999) def test_no_header(self): data = """\ @@ -824,9 +831,9 @@ def test_no_header(self): hapfile_noheader = io.StringIO(data) hapfile_header = io.StringIO("chr pos rate cM\n" + data) with pytest.raises(ValueError): - msprime.RateMap.read_hapmap(hapfile_header, has_header=False) - rm1 = msprime.RateMap.read_hapmap(hapfile_header) - rm2 = msprime.RateMap.read_hapmap(hapfile_noheader, has_header=False) + tskit.RateMap.read_hapmap(hapfile_header, has_header=False) + rm1 = tskit.RateMap.read_hapmap(hapfile_header) + rm2 = tskit.RateMap.read_hapmap(hapfile_noheader, has_header=False) assert_array_equal(rm1.rate, rm2.rate) assert_array_equal(rm1.position, rm2.position) @@ -844,8 +851,8 @@ def test_hapmap_fragment(self): 1 9752154 0.0864316558730679 1.25601286485381 1 9881751 0.0 1.26721414815999""" ) - rm1 = msprime.RateMap.read_hapmap(hapfile) + rm1 = tskit.RateMap.read_hapmap(hapfile) hapfile.seek(0) - rm2 = msprime.RateMap.read_hapmap(hapfile, rate_col=2) + rm2 = tskit.RateMap.read_hapmap(hapfile, rate_col=2) assert np.allclose(rm1.position, rm2.position) assert np.allclose(rm1.rate, rm2.rate, equal_nan=True) diff --git a/python/tskit/__init__.py b/python/tskit/__init__.py index 09e16091e5..c1b153be04 100644 --- a/python/tskit/__init__.py +++ 
b/python/tskit/__init__.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -90,3 +90,4 @@ from tskit.util import * # NOQA from tskit.metadata import * # NOQA from tskit.text_formats import * # NOQA +from tskit.intervals import RateMap # NOQA diff --git a/python/tskit/intervals.py b/python/tskit/intervals.py index 60c630c5c2..ef371531e0 100644 --- a/python/tskit/intervals.py +++ b/python/tskit/intervals.py @@ -27,12 +27,12 @@ from __future__ import annotations import collections.abc -import itertools import numbers -import warnings import numpy as np -from msprime import core + +import tskit +import tskit.util as util class RateMap(collections.abc.Mapping): @@ -334,53 +334,42 @@ def __getitem__(self, key): # if the returned array contains any nans though. raise KeyError("Key {key} not in map") - def _display_table(self): - def format_row(left, right, mid, span, rate): - return [ - f"{left:.10g}", - f"{right:.10g}", - f"{mid:.10g}", - f"{span:.10g}", - f"{rate:.2g}", - ] - - def format_slice(start, end): - return list( - itertools.starmap( - format_row, - zip( - self.left[start:end], - self.right[start:end], - self.mid[start:end], - self.span[start:end], - self.rate[start:end], - ), + def _text_header_and_rows(self, limit=None): + headers = ("left", "right", "mid", "span", "rate") + num_rows = len(self.left) + rows = [] + row_indexes = util.truncate_rows(num_rows, limit) + for j in row_indexes: + if j == -1: + rows.append(f"__skipped__{num_rows-limit}") + else: + rows.append( + [ + f"{self.left[j]:.10g}", + f"{self.right[j]:.10g}", + f"{self.mid[j]:.10g}", + f"{self.span[j]:.10g}", + f"{self.rate[j]:.2g}", + ] ) - ) - - if self.num_intervals < 40: - data = format_slice(0, None) - else: - data = format_slice(0, 10) - data.append(["⋯"] * 
5) - data += format_slice(-10, None) - - return ["left", "right", "mid", "span", "rate"], data + return headers, rows def __str__(self): - titles, data = self._display_table() - data = [[[item] for item in row] for row in data] - table = core.text_table( - caption="", - column_titles=[[title] for title in titles], + header, rows = self._text_header_and_rows( + limit=tskit._print_options["max_lines"] + ) + table = util.unicode_table( + rows=rows, + header=header, column_alignments="<<>>>", - data=data, ) return table def _repr_html_(self): - col_titles, data = self._display_table() - return core.html_table("", col_titles, data) + header, rows = self._text_header_and_rows( + limit=tskit._print_options["max_lines"] + ) + return util.html_table(rows, header=header) def __repr__(self): return f"RateMap(position={repr(self.position)}, rate={repr(self.rate)})" @@ -610,155 +599,3 @@ def read_hapmap( if end != physical_positions[-1]: rate[-1] = np.nan return RateMap(position=physical_positions, rate=rate) - - -class RecombinationMap: - """ - A RecombinationMap represents the changing rates of recombination - along a chromosome. This is defined via two lists of numbers: - ``positions`` and ``rates``, which must be of the same length. - Given an index j in these lists, the rate of recombination - per base per generation is ``rates[j]`` over the interval - ``positions[j]`` to ``positions[j + 1]``. Consequently, the first - position must be zero, and by convention the last rate value - is also required to be zero (although it is not used). - - .. important:: - This class is deprecated (but supported indefinitely); - please use the :class:`.RateMap` class in new code. - In particular, note that when specifying ``rates`` in the - the :class:`.RateMap` class we now require an array - of length :math:`n - 1` (this class requires an array - of length :math:`n` in which the last entry is zero). 
- - :param list positions: The positions (in bases) denoting the - distinct intervals where recombination rates change. These can - be floating point values. - :param list rates: The list of rates corresponding to the supplied - ``positions``. Recombination rates are specified per base, - per generation. - :param int num_loci: **This parameter is no longer supported.** - Must be either None (meaning a continuous genome of the - finest possible resolution) or be equal to ``positions[-1]`` - (meaning a discrete genome). Any other value will result in - an error. Please see the :ref:`sec_legacy_0x_genome_discretisation` - section for more information. - """ - - def __init__(self, positions, rates, num_loci=None, map_start=0): - # Used as an internal flag for the 0.x simulate() function. This allows - # us to emulate the discrete-sites behaviour of 0.x code. - self._is_discrete = num_loci == positions[-1] - if num_loci is not None and num_loci != positions[-1]: - raise ValueError( - "The RecombinationMap interface is deprecated and only " - "partially supported. If you wish to simulate a number of " - "discrete loci, you must set num_loci == the sequence length. " - "If you wish to simulate recombination process on as fine " - "a map as possible, please omit the num_loci parameter (or set " - "to None). Otherwise, num_loci is no longer supported and " - "the behaviour of msprime 0.x cannot be emulated. Please " - "consider upgrading your code to the version 1.x APIs." - ) - self.map = RateMap(position=positions, rate=rates[:-1]) - - @classmethod - def uniform_map(cls, length, rate, num_loci=None): - """ - Returns a :class:`.RecombinationMap` instance in which the recombination - rate is constant over a chromosome of the specified length. - The legacy ``num_loci`` option is no longer supported and should not be used. - - :param float length: The length of the chromosome. - :param float rate: The rate of recombination per unit of sequence length - along this chromosome. 
- :param int num_loci: This parameter is no longer supported. - """ - return cls([0, length], [rate, 0], num_loci=num_loci) - - @classmethod - def read_hapmap(cls, filename): - """ - Parses the specified file in HapMap format. - - .. warning:: - This method is deprecated, use the :meth:`.RateMap.read_hapmap` - method instead. - - :param str filename: The name of the file to be parsed. This may be - in plain text or gzipped plain text. - :return: A RecombinationMap object. - """ - warnings.warn( - "RecombinationMap.read_hapmap() is deprecated. " - "Use RateMap.read_hapmap() instead.", - FutureWarning, - ) - rate_map = RateMap.read_hapmap(filename, position_col=1, rate_col=2) - # Mark anything missing as 0 for backwards compatibility. This will - # ensure that simulate() never trims parts of the tree sequence. - rate = rate_map.rate.copy() - rate[rate_map.missing] = 0 - return cls(rate_map.position, np.append(rate, 0)) - - @property - def mean_recombination_rate(self): - """ - Return the weighted mean recombination rate - across all windows of the entire recombination map. - """ - return self.map.mean_rate - - def get_total_recombination_rate(self): - """ - Returns the effective recombination rate for this genetic map. - This is the weighted mean of the rates across all intervals. - """ - return self.map.total_mass - - def physical_to_genetic(self, x): - return self.map.get_cumulative_mass(x) - - def genetic_to_physical(self, genetic_x): - if self.map.total_mass == 0: - # If we have a zero recombination rate throughout then everything - # except L maps to 0. - return self.get_sequence_length() if genetic_x > 0 else 0 - if genetic_x == 0: - return self.map.position[0] - # TODO refactor this to this to use get_cumulative_mass() function / add the - # corresponding high-level function to the rate map. 
- index = np.searchsorted(self.map._cumulative_mass, genetic_x) - 1 - y = ( - self.map.position[index] - + (genetic_x - self.map._cumulative_mass[index]) / self.map.rate[index] - ) - return y - - def physical_to_discrete_genetic(self, physical_x): - raise ValueError("Discrete genetic space is no longer supported") - - def get_per_locus_recombination_rate(self): - raise ValueError("Genetic loci are no longer supported") - - def get_num_loci(self): - raise ValueError("num_loci is no longer supported") - - def get_size(self): - return len(self.map.position) - - def get_positions(self): - return list(self.map.position) - - def get_rates(self): - return list(self.map.rate) + [0] - - def get_sequence_length(self): - return self.map.sequence_length - - def get_length(self): - # Deprecated: use get_sequence_length() instead - return self.get_sequence_length() - - def asdict(self): - return self.map.asdict() diff --git a/python/tskit/tables.py b/python/tskit/tables.py index 2de36496f1..c11f8e74d8 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -1,7 +1,7 @@ # # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # Copyright (c) 2017 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -656,6 +656,9 @@ def __str__(self): return util.unicode_table(rows, header=headers, row_separator=False) def _repr_html_(self): + """ + Called e.g. 
by jupyter notebooks to render tables + """ headers, rows = self._text_header_and_rows( limit=tskit._print_options["max_lines"] ) diff --git a/python/tskit/util.py b/python/tskit/util.py index ca8ca77f99..86761a181d 100644 --- a/python/tskit/util.py +++ b/python/tskit/util.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -447,9 +447,6 @@ def unicode_table( def html_table(rows, *, header): - """ - Called by jupyter notebooks to render tables - """ headers = "".join(f"{h}" for h in header) rows = ( f'{row[11:]}' @@ -724,8 +721,8 @@ def set_print_options(*, max_lines=40): def truncate_rows(num_rows, limit=None): """ - Return a list of indexes into a set of rows, but is limit is set, truncate the - number of rows and place a `-1` instead of the intermediate indexes + Return a list of indexes into a set of rows, but if a ``limit`` is set, truncate the + number of rows and place a single ``-1`` entry, instead of the intermediate indexes """ if limit is None or num_rows <= limit: return range(num_rows) From 11e0d50f82d2ff8549508b9544d77f75ae52adaf Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Fri, 13 Jan 2023 11:56:02 +0000 Subject: [PATCH 20/84] Update linting action --- .github/workflows/docs.yml | 2 +- .github/workflows/tests.yml | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index c2fcc823e5..0a0490c8ea 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -36,7 +36,7 @@ jobs: id: cache with: path: venv - key: docs-venv-v6-${{ hashFiles(env.REQUIREMENTS) }} + key: docs-venv-v7-${{ hashFiles(env.REQUIREMENTS) }} - name: Build virtualenv if: steps.cache.outputs.cache-hit != 'true' diff --git a/.github/workflows/tests.yml 
b/.github/workflows/tests.yml index abbbeae854..e7e0183cb6 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,10 +8,10 @@ on: jobs: pre-commit: name: Lint - runs-on: ubuntu-18.04 + runs-on: ubuntu-latest steps: - name: Cancel Previous Runs - uses: styfle/cancel-workflow-action@0.10.0 + uses: styfle/cancel-workflow-action@0.11.0 with: access_token: ${{ github.token }} - uses: actions/checkout@v3 @@ -21,9 +21,8 @@ jobs: - name: install clang-format if: steps.clang_format.outputs.cache-hit != 'true' run: | - sudo apt-get remove -y clang-6.0 libclang-common-6.0-dev libclang1-6.0 libllvm6.0 - sudo apt-get autoremove - sudo apt-get install clang-format clang-format-6.0 + sudo pip install clang-format==6.0.1 + sudo ln -s /usr/local/bin/clang-format /usr/local/bin/clang-format-6.0 - uses: pre-commit/action@v3.0.0 benchmark: From 5ee30fb68ff5d92ec8a82156a789a86aada03aff Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 13 Jan 2023 11:42:56 +0000 Subject: [PATCH 21/84] Implement TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS Closes #2662 --- c/tests/test_trees.c | 31 ++++++++++++++ c/tskit/tables.c | 31 ++++++++------ c/tskit/tables.h | 23 ++++++++-- python/_tskitmodule.c | 17 +++++--- python/tests/simplify.py | 27 ++++++------ python/tests/test_highlevel.py | 41 ++++++++++++------ python/tests/test_lowlevel.py | 3 ++ python/tests/test_topology.py | 78 +++++++++++++++++++++++++++------- python/tskit/tables.py | 29 ++++++------- python/tskit/trees.py | 43 +++++++++++++------ 10 files changed, 235 insertions(+), 88 deletions(-) diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c index bb70d273d6..af6e7c644f 100644 --- a/c/tests/test_trees.c +++ b/c/tests/test_trees.c @@ -3323,6 +3323,36 @@ test_simplest_no_node_filter(void) tsk_treeseq_free(&ts); } +static void +test_simplest_no_update_flags(void) +{ + const char *nodes = "0 0 0\n" + "1 0 0\n" + "0 1 0\n"; + const char *edges = "0 1 2 0,1\n"; + tsk_treeseq_t ts, simplified; + tsk_id_t 
sample_ids[] = { 0, 1 }; + int ret; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0); + + /* We have a mixture of sample and non-samples in the input tables */ + ret = tsk_treeseq_simplify( + &ts, sample_ids, 2, TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS, &simplified, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_table_collection_equals(ts.tables, simplified.tables, 0)); + tsk_treeseq_free(&simplified); + + ret = tsk_treeseq_simplify(&ts, sample_ids, 2, + TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS | TSK_SIMPLIFY_NO_FILTER_NODES, &simplified, + NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_table_collection_equals(ts.tables, simplified.tables, 0)); + tsk_treeseq_free(&simplified); + + tsk_treeseq_free(&ts); +} + static void test_simplest_map_mutations(void) { @@ -8093,6 +8123,7 @@ main(int argc, char **argv) { "test_simplest_population_filter", test_simplest_population_filter }, { "test_simplest_individual_filter", test_simplest_individual_filter }, { "test_simplest_no_node_filter", test_simplest_no_node_filter }, + { "test_simplest_no_update_flags", test_simplest_no_update_flags }, { "test_simplest_map_mutations", test_simplest_map_mutations }, { "test_simplest_nonbinary_map_mutations", test_simplest_nonbinary_map_mutations }, diff --git a/c/tskit/tables.c b/c/tskit/tables.c index ff50e6398d..3039dbbb38 100644 --- a/c/tskit/tables.c +++ b/c/tskit/tables.c @@ -8808,16 +8808,18 @@ static tsk_id_t TSK_WARN_UNUSED simplifier_record_node(simplifier_t *self, tsk_id_t input_id) { tsk_node_t node; - tsk_flags_t flags; + bool update_flags = !(self->options & TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS); tsk_node_table_get_row_unsafe(&self->input_tables.nodes, (tsk_id_t) input_id, &node); - /* Zero out the sample bit */ - flags = node.flags & (tsk_flags_t) ~TSK_NODE_IS_SAMPLE; - if (self->is_sample[input_id]) { - flags |= TSK_NODE_IS_SAMPLE; + if (update_flags) { + /* Zero out the sample bit */ + node.flags &= (tsk_flags_t) 
~TSK_NODE_IS_SAMPLE; + if (self->is_sample[input_id]) { + node.flags |= TSK_NODE_IS_SAMPLE; + } } self->node_id_map[input_id] = (tsk_id_t) self->tables->nodes.num_rows; - return tsk_node_table_add_row(&self->tables->nodes, flags, node.time, + return tsk_node_table_add_row(&self->tables->nodes, node.flags, node.time, node.population, node.individual, node.metadata, node.metadata_length); } @@ -9108,6 +9110,7 @@ simplifier_init_nodes(simplifier_t *self, const tsk_id_t *samples) tsk_size_t j; const tsk_size_t num_nodes = self->input_tables.nodes.num_rows; bool filter_nodes = !(self->options & TSK_SIMPLIFY_NO_FILTER_NODES); + bool update_flags = !(self->options & TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS); tsk_flags_t *node_flags = self->tables->nodes.flags; tsk_id_t *node_id_map = self->node_id_map; @@ -9123,13 +9126,17 @@ simplifier_init_nodes(simplifier_t *self, const tsk_id_t *samples) } } else { tsk_bug_assert(self->tables->nodes.num_rows == num_nodes); - /* The node table has not been changed */ - for (j = 0; j < num_nodes; j++) { - /* Reset the sample flags */ - node_flags[j] &= (tsk_flags_t) ~TSK_NODE_IS_SAMPLE; - if (self->is_sample[j]) { - node_flags[j] |= TSK_NODE_IS_SAMPLE; + if (update_flags) { + for (j = 0; j < num_nodes; j++) { + /* Reset the sample flags */ + node_flags[j] &= (tsk_flags_t) ~TSK_NODE_IS_SAMPLE; + if (self->is_sample[j]) { + node_flags[j] |= TSK_NODE_IS_SAMPLE; + } } + } + + for (j = 0; j < num_nodes; j++) { node_id_map[j] = (tsk_id_t) j; } } diff --git a/c/tskit/tables.h b/c/tskit/tables.h index bd69b9cc95..a3496bb9ef 100644 --- a/c/tskit/tables.h +++ b/c/tskit/tables.h @@ -694,6 +694,10 @@ first. */ #define TSK_SIMPLIFY_NO_FILTER_NODES (1 << 7) /** +Do not update the sample status of nodes as a result of simplification. +*/ +#define TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS (1 << 8) +/** Reduce the topological information in the tables to the minimum necessary to represent the trees that contain sites. 
If there are zero sites this will result in an zero output edges. When the number of sites is greater than zero, @@ -3919,9 +3923,10 @@ or :c:macro:`TSK_NULL` if the node has been removed. Thus, ``node_map`` must be of at least ``self->nodes.num_rows`` :c:type:`tsk_id_t` values. If the `TSK_SIMPLIFY_NO_FILTER_NODES` option is specified, the node table will be -unaltered except for changing the sample status of nodes that were samples in the -input tables, but not in the specified list of sample IDs (if provided). The -``node_map`` (if specified) will always be the identity mapping, such that +unaltered except for changing the sample status of nodes (but see the +`TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS` option below) and to update references +to other tables that may have changed as a result of filtering (see below). +The ``node_map`` (if specified) will always be the identity mapping, such that ``node_map[u] == u`` for all nodes. Note also that the order of the list of samples is not important in this case. @@ -3941,6 +3946,17 @@ sample status flag of nodes. may be entirely unreferenced entities in the input tables, which are not affected by whether we filter nodes or not. +By default, the node sample flags are updated by unsetting the +:c:macro:`TSK_NODE_IS_SAMPLE` flag for all nodes and subsequently setting it +for the nodes provided as input to this function. The +`TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS` option will prevent this from occuring, +making it the responsibility of calling code to keep track of the ultimate +sample status of nodes. Using this option in conjunction with +`TSK_SIMPLIFY_NO_FILTER_NODES` (and without the +`TSK_SIMPLIFY_FILTER_POPULATIONS` and `TSK_SIMPLIFY_FILTER_INDIVIDUALS` +options) guarantees that the node table will not be written to during the +lifetime of this function. + The table collection will always be unindexed after simplify successfully completes. .. 
note:: Migrations are currently not supported by simplify, and an error will @@ -3956,6 +3972,7 @@ Options can be specified by providing one or more of the following bitwise - :c:macro:`TSK_SIMPLIFY_FILTER_POPULATIONS` - :c:macro:`TSK_SIMPLIFY_FILTER_INDIVIDUALS` - :c:macro:`TSK_SIMPLIFY_NO_FILTER_NODES` +- :c:macro:`TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS` - :c:macro:`TSK_SIMPLIFY_REDUCE_TO_SITE_TOPOLOGY` - :c:macro:`TSK_SIMPLIFY_KEEP_UNARY` - :c:macro:`TSK_SIMPLIFY_KEEP_INPUT_ROOTS` diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 6d5248b209..5c7adedb09 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -6589,21 +6589,23 @@ TableCollection_simplify(TableCollection *self, PyObject *args, PyObject *kwds) int filter_individuals = false; int filter_populations = false; int filter_nodes = true; + int update_sample_flags = true; int keep_unary = false; int keep_unary_in_individuals = false; int keep_input_roots = false; int reduce_to_site_topology = false; - static char *kwlist[] = { "samples", "filter_sites", "filter_populations", - "filter_individuals", "filter_nodes", "reduce_to_site_topology", "keep_unary", - "keep_unary_in_individuals", "keep_input_roots", NULL }; + static char *kwlist[] + = { "samples", "filter_sites", "filter_populations", "filter_individuals", + "filter_nodes", "update_sample_flags", "reduce_to_site_topology", + "keep_unary", "keep_unary_in_individuals", "keep_input_roots", NULL }; if (TableCollection_check_state(self) != 0) { goto out; } - if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|iiiiiiii", kwlist, &samples, + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|iiiiiiiii", kwlist, &samples, &filter_sites, &filter_populations, &filter_individuals, &filter_nodes, - &reduce_to_site_topology, &keep_unary, &keep_unary_in_individuals, - &keep_input_roots)) { + &update_sample_flags, &reduce_to_site_topology, &keep_unary, + &keep_unary_in_individuals, &keep_input_roots)) { goto out; } samples_array = (PyArrayObject *) 
PyArray_FROMANY( @@ -6625,6 +6627,9 @@ TableCollection_simplify(TableCollection *self, PyObject *args, PyObject *kwds) if (!filter_nodes) { options |= TSK_SIMPLIFY_NO_FILTER_NODES; } + if (!update_sample_flags) { + options |= TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS; + } if (reduce_to_site_topology) { options |= TSK_SIMPLIFY_REDUCE_TO_SITE_TOPOLOGY; } diff --git a/python/tests/simplify.py b/python/tests/simplify.py index ef6ebdc25f..6505aec05d 100644 --- a/python/tests/simplify.py +++ b/python/tests/simplify.py @@ -111,7 +111,8 @@ def __init__( keep_unary=False, keep_unary_in_individuals=False, keep_input_roots=False, - filter_nodes=True, # If this is False, the order in `sample` is ignored + filter_nodes=True, + update_sample_flags=True, ): self.ts = ts self.n = len(sample) @@ -121,6 +122,7 @@ def __init__( self.filter_populations = filter_populations self.filter_individuals = filter_individuals self.filter_nodes = filter_nodes + self.update_sample_flags = update_sample_flags self.keep_unary = keep_unary self.keep_unary_in_individuals = keep_unary_in_individuals self.keep_input_roots = keep_input_roots @@ -152,14 +154,14 @@ def __init__( # NOTE In the C implementation we would really just not touch the # original tables. 
self.tables.nodes.replace_with(self.ts.tables.nodes) - # TODO make this optional somehow - flags = self.tables.nodes.flags - # Zero out other sample flags - flags = np.bitwise_and(flags, ~tskit.NODE_IS_SAMPLE) - flags[sample] |= tskit.NODE_IS_SAMPLE - self.tables.nodes.flags = flags.astype(np.uint32) - self.node_id_map[:] = np.arange(ts.num_nodes) + if self.update_sample_flags: + flags = self.tables.nodes.flags + # Zero out other sample flags + flags = np.bitwise_and(flags, ~tskit.NODE_IS_SAMPLE) + flags[sample] |= tskit.NODE_IS_SAMPLE + self.tables.nodes.flags = flags.astype(np.uint32) + self.node_id_map[:] = np.arange(ts.num_nodes) for sample_id in sample: self.add_ancestry(sample_id, 0, self.sequence_length, sample_id) else: @@ -178,10 +180,11 @@ def record_node(self, input_id): """ node = self.ts.node(input_id) flags = node.flags - # Need to zero out the sample flag - flags &= ~tskit.NODE_IS_SAMPLE - if self.is_sample[input_id]: - flags |= tskit.NODE_IS_SAMPLE + if self.update_sample_flags: + # Need to zero out the sample flag + flags &= ~tskit.NODE_IS_SAMPLE + if self.is_sample[input_id]: + flags |= tskit.NODE_IS_SAMPLE output_id = self.tables.nodes.append(node.replace(flags=flags)) self.node_id_map[input_id] = output_id return output_id diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index 3e3a0de8cf..c926121aea 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -254,19 +254,19 @@ def get_internal_samples_examples(): # Set all nodes to be samples. flags[:] = tskit.NODE_IS_SAMPLE nodes.flags = flags - ret.append(("all nodes samples", tables.tree_sequence())) + ret.append(("all_nodes_samples", tables.tree_sequence())) # Set just internal nodes to be samples. flags[:] = 0 flags[n:] = tskit.NODE_IS_SAMPLE nodes.flags = flags - ret.append(("internal nodes samples", tables.tree_sequence())) + ret.append(("internal_nodes_samples", tables.tree_sequence())) # Set a mixture of internal and leaf samples. 
flags[:] = 0 flags[n // 2 : n + n // 2] = tskit.NODE_IS_SAMPLE nodes.flags = flags - ret.append(("mixture of internal and leaf samples", tables.tree_sequence())) + ret.append(("mixed_internal_leaf_samples", tables.tree_sequence())) return ret @@ -281,7 +281,7 @@ def get_decapitated_examples(): ts = msprime.simulate(20, recombination_rate=1, random_seed=1234) assert ts.num_trees > 2 - ret.append(("decapitate recomb", ts.decapitate(ts.tables.nodes.time[-1] / 4))) + ret.append(("decapitate_recomb", ts.decapitate(ts.tables.nodes.time[-1] / 4))) return ret @@ -302,7 +302,7 @@ def get_bottleneck_examples(): demographic_events=bottlenecks, random_seed=n, ) - yield (f"bottleneck n={n}", ts) + yield (f"bottleneck_n={n}", ts) def get_back_mutation_examples(): @@ -337,13 +337,13 @@ def make_example_tree_sequences(): ) ts = tsutil.insert_random_ploidy_individuals(ts, 4, seed=seed) yield ( - f"n={n} m={m} rho={rho}", + f"n={n}_m={m}_rho={rho}", tsutil.add_random_metadata(ts, seed=seed), ) seed += 1 for name, ts in get_bottleneck_examples(): yield ( - f"{name} mutated", + f"{name}_mutated", msprime.mutate( ts, rate=0.1, @@ -352,7 +352,7 @@ def make_example_tree_sequences(): ), ) ts = tskit.Tree.generate_balanced(8).tree_sequence - yield ("rev node order", ts.subset(np.arange(ts.num_nodes - 1, -1, -1))) + yield ("rev_node_order", ts.subset(np.arange(ts.num_nodes - 1, -1, -1))) ts = msprime.sim_ancestry( 8, sequence_length=40, recombination_rate=0.1, random_seed=seed ) @@ -361,20 +361,20 @@ def make_example_tree_sequences(): ts = tables.tree_sequence() assert ts.num_trees > 1 yield ( - "back mutations", + "back_mutations", tsutil.insert_branch_mutations(ts, mutations_per_branch=2), ) ts = tsutil.insert_multichar_mutations(ts) yield ("multichar", ts) - yield ("multichar w/ metadata", tsutil.add_random_metadata(ts)) + yield ("multichar_no_metadata", tsutil.add_random_metadata(ts)) tables = ts.dump_tables() tables.nodes.flags = np.zeros_like(tables.nodes.flags) - yield ("no samples", 
tables.tree_sequence()) # no samples + yield ("no_samples", tables.tree_sequence()) # no samples tables = ts.dump_tables() tables.edges.clear() - yield ("empty tree", tables.tree_sequence()) # empty tree + yield ("empty_tree", tables.tree_sequence()) # empty tree yield ( - "empty ts", + "empty_ts", tskit.TableCollection(sequence_length=1).tree_sequence(), ) # empty tree seq yield ("all_fields", tsutil.all_fields_ts()) @@ -384,6 +384,8 @@ def make_example_tree_sequences(): def get_example_tree_sequences(pytest_params=True): + # NOTE: pytest names should not contain spaces and be shell safe so + # that they can be easily specified on the command line. if pytest_params: return [pytest.param(ts, id=name) for name, ts in _examples] else: @@ -2785,6 +2787,19 @@ def test_simplify_migrations_fails(self): with pytest.raises(_tskit.LibraryError): ts.simplify() + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_no_update_sample_flags_no_filter_nodes(self, ts): + # Can't simplify edges with metadata + if ts.tables.edges.metadata_schema == tskit.MetadataSchema(schema=None): + k = min(ts.num_samples, 3) + subset = ts.samples()[:k] + ts1 = ts.simplify(subset) + ts2 = ts.simplify(subset, update_sample_flags=False, filter_nodes=False) + assert ts1.num_samples == len(subset) + assert ts2.num_samples == ts.num_samples + assert ts1.num_edges == ts2.num_edges + assert ts2.tables.nodes == ts.tables.nodes + class TestMinMaxTime: def get_example_tree_sequence(self, use_unknown_time): diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 31abe8b6e3..e599ad62c9 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -349,6 +349,8 @@ def test_simplify_bad_args(self): tc.simplify([0, 1], filter_populations="x") with pytest.raises(TypeError): tc.simplify([0, 1], filter_nodes="x") + with pytest.raises(TypeError): + tc.simplify([0, 1], update_sample_flags="x") with pytest.raises(_tskit.LibraryError): tc.simplify([0, -1]) 
@@ -360,6 +362,7 @@ def test_simplify_bad_args(self): "filter_populations", "filter_individuals", "filter_nodes", + "update_sample_flags", "reduce_to_site_topology", "keep_unary", "keep_unary_in_individuals", diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index 356c3fb316..d564ec0590 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # Copyright (c) 2016-2017 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -2686,7 +2686,7 @@ def verify_simplify( filter_sites=filter_sites, keep_input_roots=keep_input_roots, filter_nodes=filter_nodes, - compare_lib=True, # TMP + compare_lib=True, ) if debug: print("before") @@ -4757,6 +4757,7 @@ def do_simplify( filter_nodes=True, keep_unary=False, keep_input_roots=False, + update_sample_flags=True, ): """ Runs the Python test implementation of simplify. 
@@ -4772,6 +4773,7 @@ def do_simplify( filter_nodes=filter_nodes, keep_unary=keep_unary, keep_input_roots=keep_input_roots, + update_sample_flags=update_sample_flags, ) new_ts, node_map = s.simplify() if compare_lib: @@ -4781,28 +4783,16 @@ def do_simplify( filter_individuals=filter_individuals, filter_populations=filter_populations, filter_nodes=filter_nodes, + update_sample_flags=update_sample_flags, keep_unary=keep_unary, keep_input_roots=keep_input_roots, map_nodes=True, ) lib_tables1 = sts.dump_tables() - lib_tables2 = ts.dump_tables() - lib_node_map2 = lib_tables2.simplify( - samples, - filter_sites=filter_sites, - keep_unary=keep_unary, - keep_input_roots=keep_input_roots, - filter_individuals=filter_individuals, - filter_populations=filter_populations, - filter_nodes=filter_nodes, - ) - py_tables = new_ts.dump_tables() py_tables.assert_equals(lib_tables1, ignore_provenance=True) - py_tables.assert_equals(lib_tables2, ignore_provenance=True) assert all(node_map == lib_node_map1) - assert all(node_map == lib_node_map2) return new_ts, node_map @@ -6091,6 +6081,64 @@ def test_mutations_on_removed_branches(self): assert ts2.num_mutations == 0 +class TestSimplifyNoUpdateSampleFlags: + """ + Tests for simplify when we don't update the sample flags. + """ + + def test_simple_case_filter_nodes(self): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts1 = tskit.Tree.generate_balanced(4).tree_sequence + ts2, node_map = do_simplify( + ts1, + [0, 1, 6], + update_sample_flags=False, + ) + # Because we don't retain 2 and 3 here, they don't stay as + # samples. But, we specified 6 as a sample, so it's coming + # through where it would ordinarily be dropped. 
+ + # 2.00┊ 2 ┊ + # ┊ ┃ ┊ + # 1.00┊ 3 ┊ + # ┊ ┏┻┓ ┊ + # 0.00┊ 0 1 ┊ + # 0 1 + assert list(ts2.nodes_flags) == [1, 1, 0, 0] + tree = ts2.first() + assert list(tree.parent_array) == [3, 3, -1, 2, -1] + + def test_simple_case_no_filter_nodes(self): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts1 = tskit.Tree.generate_balanced(4).tree_sequence + ts2, node_map = do_simplify( + ts1, + [0, 1, 6], + update_sample_flags=False, + filter_nodes=False, + ) + + # 2.00┊ 6 ┊ + # ┊ ┃ ┊ + # 1.00┊ 4 ┊ + # ┊ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + assert list(ts2.nodes_flags) == list(ts1.nodes_flags) + tree = ts2.first() + assert list(tree.parent_array) == [4, 4, -1, -1, 6, -1, -1, -1] + + class TestMapToAncestors: """ Tests the AncestorMap class. diff --git a/python/tskit/tables.py b/python/tskit/tables.py index c11f8e74d8..eb3b90ffd5 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -3264,6 +3264,7 @@ def simplify( filter_individuals=None, filter_sites=None, filter_nodes=None, + update_sample_flags=None, keep_unary=False, keep_unary_in_individuals=None, keep_input_roots=False, @@ -3278,22 +3279,12 @@ def simplify( result, resulting in a NodeTable where only the first ``len(samples)`` nodes are marked as samples. The mapping from node IDs in the current set of tables to their equivalent values in the simplified tables is - also returned as a numpy array. If an array ``a`` is returned by this + returned as a numpy array. If an array ``a`` is returned by this function and ``u`` is the ID of a node in the input table, then ``a[u]`` is the ID of this node in the output table. For any node ``u`` that is not mapped into the output tables, this mapping will equal ``-1``. - If ``filter_nodes`` is False, then the output node table will be - unchanged except for updating the sample status of nodes. 
Nodes that - are in the specified list of ``samples`` will be marked as samples - in the output, and nodes that are currently marked as samples in - the node table but **not** in the specified list of ``samples`` - will have their sample flag cleared. Note also that the order of - the ``samples`` list is not meaningful when ``filter_nodes`` is False. - The returned node mapping is always the identity mapping, such that - ``a[u] == u`` for all nodes. - Tables operated on by this function must: be sorted (see :meth:`TableCollection.sort`), have children be born strictly after their parents, and the intervals on which any node is a child must be @@ -3301,10 +3292,11 @@ def simplify( requirements to specify a valid tree sequence (but the resulting tables will). - This is identical to :meth:`TreeSequence.simplify` but acts *in place* to - alter the data in this :class:`TableCollection`. Please see the - :meth:`TreeSequence.simplify` method for a description of the remaining - parameters. + .. seealso:: + This is identical to :meth:`TreeSequence.simplify` but acts *in place* to + alter the data in this :class:`TableCollection`. Please see the + :meth:`TreeSequence.simplify` method for a description of the remaining + parameters. :param list[int] samples: A list of node IDs to retain as samples. They need not be nodes marked as samples in the original tree sequence, but @@ -3331,6 +3323,10 @@ def simplify( potential change to the node table may be to change the node flags (if ``samples`` is specified and different from the existing samples). (Default: None, treated as True) + :param bool update_sample_flags: If True, update node flags to so that + nodes in the specified list of samples have the NODE_IS_SAMPLE + flag after simplification, and nodes that are not in this list + do not. (Default: None, treated as True) :param bool keep_unary: If True, preserve unary nodes (i.e. nodes with exactly one child) that exist on the path from samples to root. 
(Default: False) @@ -3374,6 +3370,8 @@ def simplify( filter_sites = True if filter_nodes is None: filter_nodes = True + if update_sample_flags is None: + update_sample_flags = True if keep_unary_in_individuals is None: keep_unary_in_individuals = False @@ -3383,6 +3381,7 @@ def simplify( filter_individuals=filter_individuals, filter_populations=filter_populations, filter_nodes=filter_nodes, + update_sample_flags=update_sample_flags, reduce_to_site_topology=reduce_to_site_topology, keep_unary=keep_unary, keep_unary_in_individuals=keep_unary_in_individuals, diff --git a/python/tskit/trees.py b/python/tskit/trees.py index d560e29eec..0b55882a68 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -6543,6 +6543,7 @@ def simplify( filter_individuals=None, filter_sites=None, filter_nodes=None, + update_sample_flags=None, keep_unary=False, keep_unary_in_individuals=None, keep_input_roots=False, @@ -6557,18 +6558,10 @@ def simplify( original tree sequence, or :data:`tskit.NULL` (-1) if ``u`` is no longer present in the simplified tree sequence. - In the returned tree sequence the only nodes flagged as samples are - those passed as ``samples``: all others are not flagged as samples. If - ``filter_nodes`` is True (the default), nodes in the returned tree - sequence are also reordered such that the node with ID ``0`` - corresponds to ``samples[0]``, node ``1`` corresponds to ``samples[1]`` - etc., and the remaining node IDs are allocated sequentially in time - order. Alternatively, if ``filter_nodes`` is False, the node order is - not changed, and the order of IDs passed to ``samples`` is irrelevant. - - If you wish to simplify a set of tables that do not satisfy all - requirements for building a TreeSequence, then use - :meth:`TableCollection.simplify`. + .. note:: + If you wish to simplify a set of tables that do not satisfy all + requirements for building a TreeSequence, then use + :meth:`TableCollection.simplify`. 
If the ``reduce_to_site_topology`` parameter is True, the returned tree sequence will contain only topological information that is necessary to @@ -6587,6 +6580,27 @@ def simplify( simplification. By setting these parameters to False, however, the corresponding tables can be preserved without changes. + If ``filter_nodes`` is False, then the output node table will be + unchanged except for updating the sample status of nodes and any ID + remappings caused by filtering individuals and populations (if the + ``filter_individuals`` and ``filter_populations`` options are enabled). + Nodes that are in the specified list of ``samples`` will be marked as + samples in the output, and nodes that are currently marked as samples + in the node table but not in the specified list of ``samples`` will + have their :data:`tskit.NODE_IS_SAMPLE` flag cleared. Note also that + the order of the ``samples`` list is not meaningful when + ``filter_nodes`` is False. In this case, the returned node mapping is + always the identity mapping, such that ``a[u] == u`` for all nodes. + + Setting the ``update_sample_flags`` parameter to False disables the + automatic sample status update of nodes (described above) from + occuring, making it the responsibility of calling code to keep track of + the ultimate sample status of nodes. This is an advanced option, mostly + of use when combined with the ``filter_nodes=False``, + ``filter_populations=False`` and ``filter_individuals=False`` options, + which then guarantees that the node table will not be altered by + simplification. + :param list[int] samples: A list of node IDs to retain as samples. They need not be nodes marked as samples in the original tree sequence, but will constitute the entire set of samples in the returned tree sequence. @@ -6616,6 +6630,10 @@ def simplify( potential change to the node table may be to change the node flags (if ``samples`` is specified and different from the existing samples). 
(Default: None, treated as True) + :param bool update_sample_flags: If True, update node flags to so that + nodes in the specified list of samples have the NODE_IS_SAMPLE + flag after simplification, and nodes that are not in this list + do not. (Default: None, treated as True) :param bool keep_unary: If True, preserve unary nodes (i.e., nodes with exactly one child) that exist on the path from samples to root. (Default: False) @@ -6648,6 +6666,7 @@ def simplify( filter_individuals=filter_individuals, filter_sites=filter_sites, filter_nodes=filter_nodes, + update_sample_flags=update_sample_flags, keep_unary=keep_unary, keep_unary_in_individuals=keep_unary_in_individuals, keep_input_roots=keep_input_roots, From 4cdb0bf47d8cc52fa6830f2e09c63338258ae40b Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Thu, 12 Jan 2023 15:40:23 +0000 Subject: [PATCH 22/84] Improve error messages for legacy and compressed formats --- python/_tskitmodule.c | 10 ++-- .../requirements/CI-complete/requirements.txt | 1 + .../CI-tests-pip/requirements.txt | 3 +- python/requirements/development.txt | 1 + python/tests/test_file_format.py | 20 ++++++- python/tests/test_fileobj.py | 59 ++++++++++++++++++- python/tskit/tables.py | 18 ++++-- python/tskit/trees.py | 2 + python/tskit/util.py | 30 ++++++++++ 9 files changed, 128 insertions(+), 16 deletions(-) diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 5c7adedb09..5ed421c242 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * Copyright (c) 2015-2018 University of Oxford * * Permission is hereby granted, free of charge, to any person obtaining a copy @@ -221,10 +221,10 @@ handle_library_error(int err) { int kas_err; const char *not_kas_format_msg - = "File not in kastore format. 
If this file " - "was generated by msprime < 0.6.0 (June 2018) it uses the old HDF5-based " - "format which can no longer be read directly. Please convert to the new " - "kastore format using the ``tskit upgrade`` command."; + = "File not in kastore format. Either the file is corrupt or it is not a " + "tskit tree sequence file. It may be a legacy HDF file upgradable with " + "`tskit upgrade` or a compressed tree sequence file that can be decompressed " + "with `tszip`."; const char *ibd_pairs_not_stored_msg = "Sample pairs are not stored by default " "in the IdentitySegments object returned by ibd_segments(), and you have " diff --git a/python/requirements/CI-complete/requirements.txt b/python/requirements/CI-complete/requirements.txt index 6dcee67d6d..a582f42387 100644 --- a/python/requirements/CI-complete/requirements.txt +++ b/python/requirements/CI-complete/requirements.txt @@ -19,4 +19,5 @@ pytest==7.1.3 pytest-cov==4.0.0 pytest-xdist==2.5.0 svgwrite==1.4.3 +tszip==0.2.2 xmlunittest==0.5.0 diff --git a/python/requirements/CI-tests-pip/requirements.txt b/python/requirements/CI-tests-pip/requirements.txt index f83b753a5d..e9e16c64e5 100644 --- a/python/requirements/CI-tests-pip/requirements.txt +++ b/python/requirements/CI-tests-pip/requirements.txt @@ -11,4 +11,5 @@ biopython==1.79 dendropy==4.5.2 networkx==2.6.3 # Held at 2.6.3 for Python 3.7 compatibility msgpack==1.0.4 -newick==1.3.2 \ No newline at end of file +newick==1.3.2 +tszip==0.2.2 \ No newline at end of file diff --git a/python/requirements/development.txt b/python/requirements/development.txt index 4b3753de4c..ba48881723 100644 --- a/python/requirements/development.txt +++ b/python/requirements/development.txt @@ -36,6 +36,7 @@ sphinx-jupyterbook-latex sphinxcontrib-prettyspecialmethods tqdm tskit-book-theme +tszip pydata_sphinx_theme>=0.7.2 svgwrite>=1.1.10 xmlunittest diff --git a/python/tests/test_file_format.py b/python/tests/test_file_format.py index c41327b288..26abc0f2d8 100644 --- 
a/python/tests/test_file_format.py +++ b/python/tests/test_file_format.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # Copyright (c) 2016-2018 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -36,12 +36,12 @@ import msprime import numpy as np import pytest +import tszip as tszip import tests.tsutil as tsutil import tskit import tskit.exceptions as exceptions - CURRENT_FILE_MAJOR = 12 CURRENT_FILE_MINOR = 7 @@ -262,11 +262,17 @@ def test_format_too_old_raised_for_hdf5(self): ] for filename in files: path = os.path.join(test_data_dir, "hdf5-formats", filename) + with pytest.raises( exceptions.FileFormatError, - match="uses the old HDF5-based format which can no longer", + match="appears to be in HDF5 format", ): tskit.load(path) + with pytest.raises( + exceptions.FileFormatError, + match="appears to be in HDF5 format", + ): + tskit.TableCollection.load(path) def test_msprime_v_0_5_0(self): path = os.path.join(test_data_dir, "hdf5-formats", "msprime-0.5.0_v10.0.hdf5") @@ -511,6 +517,14 @@ def test_no_h5py(self): with pytest.raises(ImportError, match=msg): tskit.dump_legacy(ts, path) + def test_tszip_file(self): + ts = msprime.simulate(5) + tszip.compress(ts, self.temp_file) + with pytest.raises(tskit.FileFormatError, match="appears to be in zip format"): + tskit.load(self.temp_file) + with pytest.raises(tskit.FileFormatError, match="appears to be in zip format"): + tskit.TableCollection.load(self.temp_file) + class TestDumpFormat(TestFileFormat): """ diff --git a/python/tests/test_fileobj.py b/python/tests/test_fileobj.py index e4d05b63b6..1740094462 100644 --- a/python/tests/test_fileobj.py +++ b/python/tests/test_fileobj.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this 
software and associated documentation files (the "Software"), to deal @@ -35,6 +35,7 @@ import traceback import pytest +import tszip from pytest import fixture import tskit @@ -308,3 +309,59 @@ def verify_stream(self, ts_list, client_fd): def test_single_then_multi(self, ts_fixture, replicate_ts_fixture, client_fd): self.verify_stream([ts_fixture], client_fd) self.verify_stream(replicate_ts_fixture, client_fd) + + +def write_to_fifo(path, file_path): + with open(path, "wb") as fifo: + with open(file_path, "rb") as file: + fifo.write(file.read()) + + +def read_from_fifo(path, expected_exception, error_text, read_func): + with open(path) as fifo: + with pytest.raises(expected_exception, match=error_text): + read_func(fifo) + + +def write_and_read_from_fifo(fifo_path, file_path, expected_exception, error_text): + os.mkfifo(fifo_path) + for read_func in [tskit.load, tskit.TableCollection.load]: + read_process = multiprocessing.Process( + target=read_from_fifo, + args=(fifo_path, expected_exception, error_text, read_func), + ) + read_process.start() + write_process = multiprocessing.Process( + target=write_to_fifo, args=(fifo_path, file_path) + ) + write_process.start() + write_process.join(timeout=3) + read_process.join(timeout=3) + + +@pytest.mark.skipif(IS_WINDOWS, reason="No FIFOs on Windows") +class TestBadStream: + def test_bad_stream(self, tmp_path): + fifo_path = tmp_path / "fifo" + bad_file_path = tmp_path / "bad_file" + bad_file_path.write_bytes(b"bad data") + write_and_read_from_fifo( + fifo_path, bad_file_path, tskit.FileFormatError, "not in kastore format" + ) + + def test_legacy_stream(self, tmp_path): + fifo_path = tmp_path / "fifo" + legacy_file_path = os.path.join( + os.path.dirname(__file__), "data", "hdf5-formats", "msprime-0.3.0_v2.0.hdf5" + ) + write_and_read_from_fifo( + fifo_path, legacy_file_path, tskit.FileFormatError, "not in kastore format" + ) + + def test_tszip_stream(self, tmp_path, ts_fixture): + fifo_path = tmp_path / "fifo" + 
zip_file_path = tmp_path / "tszip_file" + tszip.compress(ts_fixture, zip_file_path) + write_and_read_from_fifo( + fifo_path, zip_file_path, tskit.FileFormatError, "not in kastore format" + ) diff --git a/python/tskit/tables.py b/python/tskit/tables.py index eb3b90ffd5..6e8be4eef8 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -3197,12 +3197,18 @@ def __getstate__(self): def load(cls, file_or_path, *, skip_tables=False, skip_reference_sequence=False): file, local_file = util.convert_file_like_to_open_file(file_or_path, "rb") ll_tc = _tskit.TableCollection() - ll_tc.load( - file, - skip_tables=skip_tables, - skip_reference_sequence=skip_reference_sequence, - ) - return TableCollection(ll_tables=ll_tc) + try: + ll_tc.load( + file, + skip_tables=skip_tables, + skip_reference_sequence=skip_reference_sequence, + ) + return TableCollection(ll_tables=ll_tc) + except tskit.FileFormatError as e: + util.raise_known_file_format_errors(file, e) + finally: + if local_file: + file.close() def dump(self, file_or_path): """ diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 0b55882a68..c7cdc2cfa6 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -4093,6 +4093,8 @@ def load(cls, file_or_path, *, skip_tables=False, skip_reference_sequence=False) skip_reference_sequence=skip_reference_sequence, ) return TreeSequence(ts) + except tskit.FileFormatError as e: + util.raise_known_file_format_errors(file, e) finally: if local_file: file.close() diff --git a/python/tskit/util.py b/python/tskit/util.py index 86761a181d..72f08499d6 100644 --- a/python/tskit/util.py +++ b/python/tskit/util.py @@ -23,6 +23,7 @@ Module responsible for various utility functions used in other modules. 
""" import dataclasses +import io import itertools import json import numbers @@ -749,3 +750,32 @@ def random_nucleotides(length: numbers.Number, *, seed: Union[int, None] = None) encoded_nucleotides = np.array(list(map(ord, "ACTG")), dtype=np.int8) a = rng.choice(encoded_nucleotides, size=int(length)) return a.tobytes().decode("ascii") + + +def raise_known_file_format_errors(open_file, existing_exception): + """ + Sniffs the file for pk-zip or hdf header bytes, then raises an exception + if these are detected, if not raises the existing exception. + """ + # Check for HDF5 header bytes + try: + open_file.seek(0) + header = open_file.read(4) + except io.UnsupportedOperation: + # If we can't seek, we can't sniff the file. + raise existing_exception + if header == b"\x89HDF": + raise tskit.FileFormatError( + "The specified file appears to be in HDF5 format. This file " + "may have been generated by msprime < 0.6.0 (June 2018) which " + "can no longer be read directly. Please convert to the new " + "kastore format using the ``tskit upgrade`` command." + ) from existing_exception + if header[:2] == b"\x50\x4b": + raise tskit.FileFormatError( + "The specified file appears to be in zip format, so may be a compressed " + "tree sequence. Try using the tszip module to decompress this file before " + "loading. `pip install tszip; tsunzip ` or use " + "`tszip.decompress` in Python code." 
+ ) from existing_exception + raise existing_exception From 38fbc2e9f5950917f977ddb81909d1e1cc0c039d Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Fri, 13 Jan 2023 14:45:21 +0000 Subject: [PATCH 23/84] Release prep --- python/CHANGELOG.rst | 2 +- python/tskit/_version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index c593ba2ef9..477fa841b1 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -1,5 +1,5 @@ -------------------- -[0.5.4] - 2022-XX-XX +[0.5.4] - 2023-01-13 -------------------- **Features** diff --git a/python/tskit/_version.py b/python/tskit/_version.py index d730ceabab..c29fc15326 100644 --- a/python/tskit/_version.py +++ b/python/tskit/_version.py @@ -1,4 +1,4 @@ # Definitive location for the version number. # During development, should be x.y.z.devN # For beta should be x.y.zbN -tskit_version = "0.5.3" +tskit_version = "0.5.4" From db7423df37d486947a0aae622c7cc4c2e1bf6cf6 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Fri, 13 Jan 2023 18:24:48 +0000 Subject: [PATCH 24/84] Fix wheel building for 3.11 --- .github/workflows/wheels.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 81ab28f01e..caa1c0837c 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -127,7 +127,7 @@ jobs: - name: Build wheels in docker shell: bash run: | - docker run --rm -v `pwd`:/project -w /project quay.io/pypa/manylinux2010_x86_64 bash .github/workflows/docker/buildwheel.sh + docker run --rm -v `pwd`:/project -w /project quay.io/pypa/manylinux2014_x86_64 bash .github/workflows/docker/buildwheel.sh - name: Upload Wheels uses: actions/upload-artifact@v2 From 4bad5ec2b7bff77d24e605e569330e845a2908c0 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Fri, 13 Jan 2023 18:53:07 +0000 Subject: [PATCH 25/84] Add 311 changelog --- python/CHANGELOG.rst | 3 +++ 1 file changed, 3 insertions(+) 
diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 477fa841b1..31756c3ab2 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -29,6 +29,9 @@ be identical to the version in msprime, apart from minor changes in the formatting of tabular text output (:user:`hyanwong`, :user:`jeromekelleher`, :pr:`2678`) +- Tskit now supports and has wheels for Python 3.11. This Python version has a significant + performance boost (:user:`benjeffery`, :pr:`2624`, :issue:`2248`) + **Breaking Changes** - the ``filter_populations``, ``filter_individuals``, and ``filter_sites`` From ad4dd4a12d38cb72fb2dfb26bb68ba89337c66d9 Mon Sep 17 00:00:00 2001 From: chriscrsmith Date: Mon, 16 Jan 2023 11:49:42 -0800 Subject: [PATCH 26/84] implementing Variant.__repr__ added tests for contents edit changelog fix changelog entry --- python/CHANGELOG.rst | 10 ++++++++++ python/tests/test_genotypes.py | 16 +++++++++++++++- python/tskit/genotypes.py | 13 ++++++++++++- 3 files changed, 37 insertions(+), 2 deletions(-) diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 31756c3ab2..0d1586a854 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -1,3 +1,13 @@ +-------------------- +[0.5.5] - 2023-01-XX +-------------------- + +**Features** + +- Add ``__repr__`` for variants to return a string representation of the raw data + without spewing megabytes of text (:user:`chriscrsmith`, :pr:`2695`, :issue:`2694`) + + -------------------- [0.5.4] - 2023-01-13 -------------------- diff --git a/python/tests/test_genotypes.py b/python/tests/test_genotypes.py index 1443e40061..329867b600 100644 --- a/python/tests/test_genotypes.py +++ b/python/tests/test_genotypes.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2019-2022 Tskit Developers +# Copyright (c) 2019-2023 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -2215,3 +2215,17 @@ def 
test_variant_html_repr_no_site(self): html = v._repr_html_() ElementTree.fromstring(html) assert len(html) > 1600 + + def test_variant_repr(self, ts_fixture): + v = next(ts_fixture.variants()) + str_rep = repr(v) + assert len(str_rep) > 0 and len(str_rep) < 10000 + assert re.search(r"\AVariant", str_rep) + assert re.search(rf"\'site\': Site\(id={v.site.id}", str_rep) + assert re.search(rf"position={v.position}", str_rep) + alleles = re.escape("'alleles': " + str(v.alleles)) + assert re.search(rf"{alleles}", str_rep) + assert re.search(r"\'genotypes\': array\(\[", str_rep) + assert re.search(rf"position={v.position}", str_rep) + assert re.search(rf"\'has_missing_data\': {v.has_missing_data}", str_rep) + assert re.search(rf"\'isolated_as_missing\': {v.isolated_as_missing}", str_rep) diff --git a/python/tskit/genotypes.py b/python/tskit/genotypes.py index d0abfb3835..239e135777 100644 --- a/python/tskit/genotypes.py +++ b/python/tskit/genotypes.py @@ -1,7 +1,7 @@ # # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # Copyright (c) 2015-2018 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -340,6 +340,17 @@ def _repr_html_(self) -> str: """ return util.variant_html(self) + def __repr__(self): + d = { + "site": self.site, + "samples": self.samples, + "alleles": self.alleles, + "genotypes": self.genotypes, + "has_missing_data": self.has_missing_data, + "isolated_as_missing": self.isolated_as_missing, + } + return f"Variant({repr(d)})" + # # Miscellaneous auxiliary methods. 
From fa048833e0d5953f3d8b912eefbea2abc3f25557 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Thu, 19 Jan 2023 14:29:08 +0000 Subject: [PATCH 27/84] Freeze docs environment --- .github/workflows/docs.yml | 40 ++--- python/requirements/CI-docs/requirements.txt | 155 ++++++++++++++++++- 2 files changed, 169 insertions(+), 26 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 0a0490c8ea..c3c3523e13 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -22,7 +22,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Cancel Previous Runs - uses: styfle/cancel-workflow-action@0.6.0 + uses: styfle/cancel-workflow-action@0.11.0 with: access_token: ${{ github.token }} @@ -30,34 +30,36 @@ jobs: - uses: actions/setup-python@v4 with: - python-version: 3.8 + python-version: "3.10" + + - name: Install apt deps + if: env.APTGET + run: sudo apt-get install -y ${{env.APTGET}} - uses: actions/cache@v3 - id: cache + id: venv-cache with: path: venv - key: docs-venv-v7-${{ hashFiles(env.REQUIREMENTS) }} - - - name: Build virtualenv - if: steps.cache.outputs.cache-hit != 'true' - run: python -m venv venv - - - name: Downgrade pip - run: venv/bin/activate && pip install pip==20.0.2 - - - name: Install deps - run: venv/bin/activate && pip install -r ${{env.REQUIREMENTS}} + key: docs-venv-v1-${{ hashFiles(env.REQUIREMENTS) }} - - name: Install apt deps - if: env.APTGET - run: sudo apt-get install -y ${{env.APTGET}} + - name: Create venv and install deps (one by one to avoid conflict errors) + if: steps.venv-cache.outputs.cache-hit != 'true' + run: | + python -m venv venv + . venv/bin/activate + pip install --upgrade pip wheel + cat ${{env.REQUIREMENTS}} | sed -e '/^\s*#.*$/d' -e '/^\s*$/d' | xargs -n 1 pip install --no-dependencies - name: Build C module if: env.MAKE_TARGET - run: venv/bin/activate && make $MAKE_TARGET + run: | + . 
venv/bin/activate + make $MAKE_TARGET - name: Build Docs - run: venv/bin/activate && make -C docs + run: | + . venv/bin/activate + make -C docs - name: Trigger docs site rebuild if: github.ref == 'refs/heads/main' diff --git a/python/requirements/CI-docs/requirements.txt b/python/requirements/CI-docs/requirements.txt index 6278cfce2d..6569b3775a 100644 --- a/python/requirements/CI-docs/requirements.txt +++ b/python/requirements/CI-docs/requirements.txt @@ -1,13 +1,154 @@ +# Due to issues with indirect dependencies introducing conflicting dependencies +# we freeze everything to get a reproducible build. +alabaster==0.7.13 +anyio==3.6.2 +argon2-cffi==21.3.0 +argon2-cffi-bindings==21.2.0 +arrow==1.2.3 +asttokens==2.2.1 +attrs==21.4.0 +Babel==2.11.0 +backcall==0.2.0 +beautifulsoup4==4.11.1 +bleach==5.0.1 breathe==4.34.0 +certifi==2022.12.7 +cffi==1.15.1 +charset-normalizer==3.0.1 +click==8.1.3 +colorama==0.4.6 +comm==0.1.2 +debugpy==1.6.5 +decorator==5.1.1 +defusedxml==0.7.1 +demes==0.2.2 +Deprecated==1.2.13 +docutils==0.17.1 +entrypoints==0.4 +executing==1.2.0 +fastjsonschema==2.16.2 +fqdn==1.5.1 +gitdb==4.0.10 +GitPython==3.1.30 +greenlet==2.0.1 +idna==3.4 +imagesize==1.4.1 +importlib-metadata==6.0.0 +ipykernel==6.20.2 +ipython==8.8.0 +ipython-genutils==0.2.0 +ipywidgets==7.7.2 +isoduration==20.11.0 +jedi==0.18.2 +Jinja2==3.1.2 +jsonpointer==2.3 +jsonschema==4.17.3 jupyter-book==0.13.1 -h5py==3.7.0 -jsonschema[format-nongpl]==4.17.3 +jupyter-cache==0.4.3 +jupyter-events==0.6.3 +jupyter-server-mathjax==0.2.6 +jupyter-sphinx==0.3.2 +jupyter_client==7.4.9 +jupyter_core==5.1.3 +jupyter_server==2.1.0 +jupyter_server_terminals==0.4.4 +jupyterlab-pygments==0.2.2 +jupyterlab-widgets==1.1.1 +latexcodec==2.0.1 +linkify-it-py==1.0.3 +lxml==4.9.2 +markdown-it-py==1.1.0 +MarkupSafe==2.1.2 +matplotlib-inline==0.1.6 +mdit-py-plugins==0.2.8 +mistune==0.8.4 msprime==1.2.0 -numpy==1.21.6 # Held at 1.21.6 for Python 3.7 compatibility -PyGithub==1.55 -sphinx-argparse==0.3.1 
-sphinx-autodoc-typehints==1.18.3 # Held at 1.18.3 as that depends on sphinx>=5.2.1 while jupyter-book 0.13.1 depends on sphinx<5 +myst-nb==0.13.2 +myst-parser==0.15.2 +nbclassic==0.4.8 +nbclient==0.5.13 +nbconvert==6.5.4 +nbdime==3.1.1 +nbformat==5.7.3 +nest-asyncio==1.5.6 +newick==1.6.0 +notebook==6.5.2 +notebook_shim==0.2.2 +numpy==1.24.1 +packaging==23.0 +pandocfilters==1.5.0 +parso==0.8.3 +pbr==5.11.1 +pexpect==4.8.0 +pickleshare==0.7.5 +platformdirs==2.6.2 +prometheus-client==0.15.0 +prompt-toolkit==3.0.36 +psutil==5.9.4 +ptyprocess==0.7.0 +pure-eval==0.2.2 +pybtex==0.24.0 +pybtex-docutils==1.0.2 +pycparser==2.21 +pydata-sphinx-theme==0.8.1 +PyGithub==1.57 +Pygments==2.14.0 +PyJWT==2.6.0 +PyNaCl==1.5.0 +pyrsistent==0.19.3 +python-dateutil==2.8.2 +python-json-logger==2.0.4 +pytz==2022.7.1 +PyYAML==6.0 +pyzmq==25.0.0 +requests==2.28.2 +rfc3339-validator==0.1.4 +rfc3986-validator==0.1.1 +ruamel.yaml==0.17.21 +ruamel.yaml.clib==0.2.7 +Send2Trash==1.8.0 +six==1.16.0 +smmap==5.0.0 +sniffio==1.3.0 +snowballstemmer==2.2.0 +soupsieve==2.3.2.post1 +Sphinx==4.5.0 +sphinx-argparse==0.4.0 +sphinx-autodoc-typehints==1.19.1 +sphinx-book-theme==0.3.3 +sphinx-comments==0.0.3 +sphinx-copybutton==0.5.1 +sphinx-external-toc==0.2.4 sphinx-issues==3.0.1 +sphinx-jupyterbook-latex==0.4.7 +sphinx-multitoc-numbering==0.1.3 +sphinx-thebe==0.1.2 +sphinx-togglebutton==0.3.2 +sphinx_design==0.1.0 +sphinxcontrib-bibtex==2.5.0 +sphinxcontrib-devhelp==1.0.2 +sphinxcontrib-htmlhelp==2.0.0 +sphinxcontrib-jsmath==1.0.1 sphinxcontrib-prettyspecialmethods==0.1.0 +sphinxcontrib-qthelp==1.0.3 +sphinxcontrib-serializinghtml==1.1.5 +sphinxcontrib.applehelp==1.0.3 +SQLAlchemy==1.4.46 +stack-data==0.6.2 svgwrite==1.4.3 -tskit-book-theme==0.3.2 \ No newline at end of file +terminado==0.17.1 +tinycss2==1.2.1 +tornado==6.2 +traitlets==5.8.1 +tskit==0.5.4 +tskit-book-theme==0.3.2 +uc-micro-py==1.0.1 +uri-template==1.2.0 +urllib3==1.26.14 +wcwidth==0.2.6 +webcolors==1.12 +webencodings==0.5.1 
+websocket-client==1.4.2 +widgetsnbextension==3.6.1 +wrapt==1.14.1 +zipp==3.11.0 From 31344ff4aa3e5346e57a24b7114d0cbea3c7f601 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Fri, 20 Jan 2023 12:55:55 +0000 Subject: [PATCH 28/84] Add missing changelog --- c/CHANGELOG.rst | 4 ++++ python/CHANGELOG.rst | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/c/CHANGELOG.rst b/c/CHANGELOG.rst index 7cf7d42507..f314a6b3d3 100644 --- a/c/CHANGELOG.rst +++ b/c/CHANGELOG.rst @@ -14,6 +14,10 @@ nodes be kept in the output (:user:`jeromekelleher`, :user:`hyanwong`, :issue:`2606`, :pr:`2619`). +- Add the `TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS` option to simplify which ensures + no node sample flags are changed to allow calling code to manage sample status. + (:user:`jeromekelleher`, :issue:`2662`, :pr:`2663`). + - Guarantee that unfiltered tables are not written to unnecessarily during simplify (:user:`jeromekelleher` :pr:`2619`). diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 0d1586a854..a3738ea06b 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -42,6 +42,10 @@ - Tskit now supports and has wheels for Python 3.11. This Python version has a significant performance boost (:user:`benjeffery`, :pr:`2624`, :issue:`2248`) +- Add the `update_sample_flags` option to `simplify` which ensures + no node sample flags are changed to allow calling code to manage sample status. + (:user:`jeromekelleher`, :issue:`2662`, :pr:`2663`). 
+ **Breaking Changes** - the ``filter_populations``, ``filter_individuals``, and ``filter_sites`` From 7abb51e24f1f3f590c5dbcb7382997b3536db22c Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Thu, 26 Jan 2023 13:18:22 +0000 Subject: [PATCH 29/84] Update python version number for development --- python/tskit/_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tskit/_version.py b/python/tskit/_version.py index c29fc15326..71f3c3e882 100644 --- a/python/tskit/_version.py +++ b/python/tskit/_version.py @@ -1,4 +1,4 @@ # Definitive location for the version number. # During development, should be x.y.z.devN # For beta should be x.y.zbN -tskit_version = "0.5.4" +tskit_version = "0.5.5.dev0" From 9025f5769759e2e5cfbc6b6b0bf218426990a841 Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Fri, 27 Jan 2023 09:13:41 +0000 Subject: [PATCH 30/84] Correct doc typo --- python/tskit/trees.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index c7cdc2cfa6..39a1e4c41a 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -635,7 +635,7 @@ class Tree: :param list tracked_samples: The list of samples to be tracked and counted using the :meth:`Tree.num_tracked_samples` method. :param bool sample_lists: If True, provide more efficient access - to the samples beneath a give node using the + to the samples beneath a given node using the :meth:`Tree.samples` method. :param int root_threshold: The minimum number of samples that a node must be ancestral to for it to be in the list of roots. By default From 88fe00ded2df18741114307fd9621e3e89d6f03f Mon Sep 17 00:00:00 2001 From: "Kevin R. Thornton" Date: Thu, 12 Jan 2023 13:14:07 -0800 Subject: [PATCH 31/84] Refactor edge difference iterator to work from table collections as well as from tree sequences. 
--- c/tests/test_tables.c | 34 +++++++- c/tests/test_trees.c | 4 +- c/tskit/haplotype_matching.c | 4 +- c/tskit/tables.c | 164 ++++++++++++++++++++++++++++++++++- c/tskit/tables.h | 47 ++++++++++ c/tskit/trees.c | 161 ++-------------------------------- c/tskit/trees.h | 43 +-------- 7 files changed, 258 insertions(+), 199 deletions(-) diff --git a/c/tests/test_tables.c b/c/tests/test_tables.c index 1b99ac6fe3..3dd3697965 100644 --- a/c/tests/test_tables.c +++ b/c/tests/test_tables.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -23,6 +23,7 @@ */ #include "testlib.h" +#include "tskit/core.h" #include #include @@ -10420,6 +10421,35 @@ test_table_collection_delete_older(void) tsk_treeseq_free(&ts); } +static void +test_table_collection_edge_diffs_errors(void) +{ + int ret; + tsk_id_t ret_id; + tsk_table_collection_t tables; + tsk_diff_iter_t iter; + + ret = tsk_table_collection_init(&tables, 0); + CU_ASSERT_EQUAL(ret, 0); + tables.sequence_length = 1; + ret_id + = tsk_node_table_add_row(&tables.nodes, TSK_NODE_IS_SAMPLE, 0, -1, -1, NULL, 0); + CU_ASSERT_EQUAL(ret_id, 0); + ret_id = tsk_node_table_add_row(&tables.nodes, 0, 1, -1, -1, NULL, 0); + CU_ASSERT_EQUAL(ret_id, 1); + ret = (int) tsk_edge_table_add_row(&tables.edges, 0., 1., 1, 0, NULL, 0); + CU_ASSERT_EQUAL(ret, 0); + + ret = tsk_diff_iter_init(&iter, &tables, -1, 0); + CU_ASSERT_EQUAL(ret, TSK_ERR_TABLES_NOT_INDEXED); + + tables.nodes.time[0] = 1; + ret = tsk_diff_iter_init(&iter, &tables, -1, 0); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_NODE_TIME_ORDERING); + + tsk_table_collection_free(&tables); +} + int main(int argc, char **argv) { @@ -10542,6 +10572,8 @@ main(int argc, char **argv) { "test_table_collection_takeset_indexes", test_table_collection_takeset_indexes }, { 
"test_table_collection_delete_older", test_table_collection_delete_older }, + { "test_table_collection_edge_diffs_errors", + test_table_collection_edge_diffs_errors }, { NULL, NULL }, }; diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c index af6e7c644f..f0ced8585f 100644 --- a/c/tests/test_trees.c +++ b/c/tests/test_trees.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -407,7 +407,7 @@ verify_tree_diffs(tsk_treeseq_t *ts, tsk_flags_t options) child[j] = TSK_NULL; sib[j] = TSK_NULL; } - ret = tsk_diff_iter_init(&iter, ts, options); + ret = tsk_diff_iter_init_from_ts(&iter, ts, options); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_tree_init(&tree, ts, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); diff --git a/c/tskit/haplotype_matching.c b/c/tskit/haplotype_matching.c index 41c1bd23a0..b942da18d6 100644 --- a/c/tskit/haplotype_matching.c +++ b/c/tskit/haplotype_matching.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -250,7 +250,7 @@ tsk_ls_hmm_reset(tsk_ls_hmm_t *self) /* This is safe because we've already zero'd out the memory. 
*/ tsk_diff_iter_free(&self->diffs); - ret = tsk_diff_iter_init(&self->diffs, self->tree_sequence, false); + ret = tsk_diff_iter_init_from_ts(&self->diffs, self->tree_sequence, false); if (ret != 0) { goto out; } diff --git a/c/tskit/tables.c b/c/tskit/tables.c index 3039dbbb38..1e910bca69 100644 --- a/c/tskit/tables.c +++ b/c/tskit/tables.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * Copyright (c) 2017-2018 University of Oxford * * Permission is hereby granted, free of charge, to any person obtaining a copy @@ -13010,3 +13010,165 @@ tsk_squash_edges(tsk_edge_t *edges, tsk_size_t num_edges, tsk_size_t *num_output out: return ret; } + +/* ======================================================== * + * Tree diff iterator. + * ======================================================== */ + +int TSK_WARN_UNUSED +tsk_diff_iter_init(tsk_diff_iter_t *self, const tsk_table_collection_t *tables, + tsk_id_t num_trees, tsk_flags_t options) +{ + int ret = 0; + + tsk_bug_assert(tables != NULL); + tsk_memset(self, 0, sizeof(tsk_diff_iter_t)); + self->num_nodes = tables->nodes.num_rows; + self->num_edges = tables->edges.num_rows; + self->tables = tables; + self->insertion_index = 0; + self->removal_index = 0; + self->tree_left = 0; + self->tree_index = -1; + if (num_trees < 0) { + num_trees = tsk_table_collection_check_integrity(self->tables, TSK_CHECK_TREES); + if (num_trees < 0) { + ret = (int) num_trees; + goto out; + } + } + self->last_index = num_trees; + + if (options & TSK_INCLUDE_TERMINAL) { + self->last_index = self->last_index + 1; + } + self->edge_list_nodes = tsk_malloc(self->num_edges * sizeof(*self->edge_list_nodes)); + if (self->edge_list_nodes == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } +out: + return ret; +} + +int +tsk_diff_iter_free(tsk_diff_iter_t *self) +{ + tsk_safe_free(self->edge_list_nodes); + return 0; +} + +void +tsk_diff_iter_print_state(const tsk_diff_iter_t 
*self, FILE *out) +{ + fprintf(out, "tree_diff_iterator state\n"); + fprintf(out, "num_edges = %lld\n", (long long) self->num_edges); + fprintf(out, "insertion_index = %lld\n", (long long) self->insertion_index); + fprintf(out, "removal_index = %lld\n", (long long) self->removal_index); + fprintf(out, "tree_left = %f\n", self->tree_left); + fprintf(out, "tree_index = %lld\n", (long long) self->tree_index); +} + +int TSK_WARN_UNUSED +tsk_diff_iter_next(tsk_diff_iter_t *self, double *ret_left, double *ret_right, + tsk_edge_list_t *edges_out_ret, tsk_edge_list_t *edges_in_ret) +{ + int ret = 0; + tsk_id_t k; + const double sequence_length = self->tables->sequence_length; + double left = self->tree_left; + double right = sequence_length; + tsk_size_t next_edge_list_node = 0; + tsk_edge_list_node_t *out_head = NULL; + tsk_edge_list_node_t *out_tail = NULL; + tsk_edge_list_node_t *in_head = NULL; + tsk_edge_list_node_t *in_tail = NULL; + tsk_edge_list_node_t *w = NULL; + tsk_edge_list_t edges_out; + tsk_edge_list_t edges_in; + const tsk_edge_table_t *edges = &self->tables->edges; + const tsk_id_t *insertion_order = self->tables->indexes.edge_insertion_order; + const tsk_id_t *removal_order = self->tables->indexes.edge_removal_order; + + tsk_memset(&edges_out, 0, sizeof(edges_out)); + tsk_memset(&edges_in, 0, sizeof(edges_in)); + + if (self->tree_index + 1 < self->last_index) { + /* First we remove the stale records */ + while (self->removal_index < (tsk_id_t) self->num_edges + && left == edges->right[removal_order[self->removal_index]]) { + k = removal_order[self->removal_index]; + tsk_bug_assert(next_edge_list_node < self->num_edges); + w = &self->edge_list_nodes[next_edge_list_node]; + next_edge_list_node++; + w->edge.id = k; + w->edge.left = edges->left[k]; + w->edge.right = edges->right[k]; + w->edge.parent = edges->parent[k]; + w->edge.child = edges->child[k]; + w->edge.metadata = edges->metadata + edges->metadata_offset[k]; + w->edge.metadata_length + = 
edges->metadata_offset[k + 1] - edges->metadata_offset[k]; + w->next = NULL; + w->prev = NULL; + if (out_head == NULL) { + out_head = w; + out_tail = w; + } else { + out_tail->next = w; + w->prev = out_tail; + out_tail = w; + } + self->removal_index++; + } + edges_out.head = out_head; + edges_out.tail = out_tail; + + /* Now insert the new records */ + while (self->insertion_index < (tsk_id_t) self->num_edges + && left == edges->left[insertion_order[self->insertion_index]]) { + k = insertion_order[self->insertion_index]; + tsk_bug_assert(next_edge_list_node < self->num_edges); + w = &self->edge_list_nodes[next_edge_list_node]; + next_edge_list_node++; + w->edge.id = k; + w->edge.left = edges->left[k]; + w->edge.right = edges->right[k]; + w->edge.parent = edges->parent[k]; + w->edge.child = edges->child[k]; + w->edge.metadata = edges->metadata + edges->metadata_offset[k]; + w->edge.metadata_length + = edges->metadata_offset[k + 1] - edges->metadata_offset[k]; + w->next = NULL; + w->prev = NULL; + if (in_head == NULL) { + in_head = w; + in_tail = w; + } else { + in_tail->next = w; + w->prev = in_tail; + in_tail = w; + } + self->insertion_index++; + } + edges_in.head = in_head; + edges_in.tail = in_tail; + + right = sequence_length; + if (self->insertion_index < (tsk_id_t) self->num_edges) { + right = TSK_MIN(right, edges->left[insertion_order[self->insertion_index]]); + } + if (self->removal_index < (tsk_id_t) self->num_edges) { + right = TSK_MIN(right, edges->right[removal_order[self->removal_index]]); + } + self->tree_index++; + ret = TSK_TREE_OK; + } + *edges_out_ret = edges_out; + *edges_in_ret = edges_in; + *ret_left = left; + *ret_right = right; + /* Set the left coordinate for the next tree */ + self->tree_left = right; + return ret; +} diff --git a/c/tskit/tables.h b/c/tskit/tables.h index a3496bb9ef..321a675271 100644 --- a/c/tskit/tables.h +++ b/c/tskit/tables.h @@ -670,6 +670,30 @@ typedef struct { bool store_pairs; } tsk_identity_segments_t; +/* Diff 
iterator. */ +typedef struct _tsk_edge_list_node_t { + tsk_edge_t edge; + struct _tsk_edge_list_node_t *next; + struct _tsk_edge_list_node_t *prev; +} tsk_edge_list_node_t; + +typedef struct { + tsk_edge_list_node_t *head; + tsk_edge_list_node_t *tail; +} tsk_edge_list_t; + +typedef struct { + tsk_size_t num_nodes; + tsk_size_t num_edges; + double tree_left; + const tsk_table_collection_t *tables; + tsk_id_t insertion_index; + tsk_id_t removal_index; + tsk_id_t tree_index; + tsk_id_t last_index; + tsk_edge_list_node_t *edge_list_nodes; +} tsk_diff_iter_t; + /****************************************************************************/ /* Common function options */ /****************************************************************************/ @@ -892,6 +916,16 @@ top-level information of the table collections being compared. #define TSK_CLEAR_PROVENANCE (1 << 2) /** @} */ +/* For the edge diff iterator */ +#define TSK_INCLUDE_TERMINAL (1 << 0) + +/** @brief Value returned by seeking methods when they have successfully + seeked to a non-null tree. + + @ingroup TREE_API_SEEKING_GROUP +*/ +#define TSK_TREE_OK 1 + /****************************************************************************/ /* Function signatures */ /****************************************************************************/ @@ -4417,6 +4451,19 @@ int tsk_identity_segments_get(const tsk_identity_segments_t *self, tsk_id_t a, void tsk_identity_segments_print_state(tsk_identity_segments_t *self, FILE *out); int tsk_identity_segments_free(tsk_identity_segments_t *self); +/* Edge differences */ + +/* Internal API - currently used in a few places, but a better API is envisaged + * at some point. + * IMPORTANT: tskit-rust uses this API, so don't break without discussing! 
+ */ +int tsk_diff_iter_init(tsk_diff_iter_t *self, const tsk_table_collection_t *tables, + tsk_id_t num_trees, tsk_flags_t options); +int tsk_diff_iter_free(tsk_diff_iter_t *self); +int tsk_diff_iter_next(tsk_diff_iter_t *self, double *left, double *right, + tsk_edge_list_t *edges_out, tsk_edge_list_t *edges_in); +void tsk_diff_iter_print_state(const tsk_diff_iter_t *self, FILE *out); + #ifdef __cplusplus } #endif diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 4fcb2ee376..8d202d0163 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * Copyright (c) 2015-2018 University of Oxford * * Permission is hereby granted, free of charge, to any person obtaining a copy @@ -5344,159 +5344,16 @@ tsk_tree_map_mutations(tsk_tree_t *self, int32_t *genotypes, return ret; } -/* ======================================================== * - * Tree diff iterator. - * ======================================================== */ - +/* Compatibility shim for initialising the diff iterator from a tree sequence. We are + * using this function in a small number of places internally, so simplest to keep it + * until a more satisfactory "diff" API comes along. 
+ */ int TSK_WARN_UNUSED -tsk_diff_iter_init( +tsk_diff_iter_init_from_ts( tsk_diff_iter_t *self, const tsk_treeseq_t *tree_sequence, tsk_flags_t options) { - int ret = 0; - - tsk_bug_assert(tree_sequence != NULL); - tsk_memset(self, 0, sizeof(tsk_diff_iter_t)); - self->num_nodes = tsk_treeseq_get_num_nodes(tree_sequence); - self->num_edges = tsk_treeseq_get_num_edges(tree_sequence); - self->tree_sequence = tree_sequence; - self->insertion_index = 0; - self->removal_index = 0; - self->tree_left = 0; - self->tree_index = -1; - self->last_index = (tsk_id_t) tsk_treeseq_get_num_trees(tree_sequence); - if (options & TSK_INCLUDE_TERMINAL) { - self->last_index = self->last_index + 1; - } - self->edge_list_nodes = tsk_malloc(self->num_edges * sizeof(*self->edge_list_nodes)); - if (self->edge_list_nodes == NULL) { - ret = TSK_ERR_NO_MEMORY; - goto out; - } -out: - return ret; -} - -int -tsk_diff_iter_free(tsk_diff_iter_t *self) -{ - tsk_safe_free(self->edge_list_nodes); - return 0; -} - -void -tsk_diff_iter_print_state(const tsk_diff_iter_t *self, FILE *out) -{ - fprintf(out, "tree_diff_iterator state\n"); - fprintf(out, "num_edges = %lld\n", (long long) self->num_edges); - fprintf(out, "insertion_index = %lld\n", (long long) self->insertion_index); - fprintf(out, "removal_index = %lld\n", (long long) self->removal_index); - fprintf(out, "tree_left = %f\n", self->tree_left); - fprintf(out, "tree_index = %lld\n", (long long) self->tree_index); -} - -int TSK_WARN_UNUSED -tsk_diff_iter_next(tsk_diff_iter_t *self, double *ret_left, double *ret_right, - tsk_edge_list_t *edges_out_ret, tsk_edge_list_t *edges_in_ret) -{ - int ret = 0; - tsk_id_t k; - const double sequence_length = self->tree_sequence->tables->sequence_length; - double left = self->tree_left; - double right = sequence_length; - tsk_size_t next_edge_list_node = 0; - const tsk_treeseq_t *s = self->tree_sequence; - tsk_edge_list_node_t *out_head = NULL; - tsk_edge_list_node_t *out_tail = NULL; - tsk_edge_list_node_t 
*in_head = NULL; - tsk_edge_list_node_t *in_tail = NULL; - tsk_edge_list_node_t *w = NULL; - tsk_edge_list_t edges_out; - tsk_edge_list_t edges_in; - const tsk_edge_table_t *edges = &s->tables->edges; - const tsk_id_t *insertion_order = s->tables->indexes.edge_insertion_order; - const tsk_id_t *removal_order = s->tables->indexes.edge_removal_order; - - tsk_memset(&edges_out, 0, sizeof(edges_out)); - tsk_memset(&edges_in, 0, sizeof(edges_in)); - - if (self->tree_index + 1 < self->last_index) { - /* First we remove the stale records */ - while (self->removal_index < (tsk_id_t) self->num_edges - && left == edges->right[removal_order[self->removal_index]]) { - k = removal_order[self->removal_index]; - tsk_bug_assert(next_edge_list_node < self->num_edges); - w = &self->edge_list_nodes[next_edge_list_node]; - next_edge_list_node++; - w->edge.id = k; - w->edge.left = edges->left[k]; - w->edge.right = edges->right[k]; - w->edge.parent = edges->parent[k]; - w->edge.child = edges->child[k]; - w->edge.metadata = edges->metadata + edges->metadata_offset[k]; - w->edge.metadata_length - = edges->metadata_offset[k + 1] - edges->metadata_offset[k]; - w->next = NULL; - w->prev = NULL; - if (out_head == NULL) { - out_head = w; - out_tail = w; - } else { - out_tail->next = w; - w->prev = out_tail; - out_tail = w; - } - self->removal_index++; - } - edges_out.head = out_head; - edges_out.tail = out_tail; - - /* Now insert the new records */ - while (self->insertion_index < (tsk_id_t) self->num_edges - && left == edges->left[insertion_order[self->insertion_index]]) { - k = insertion_order[self->insertion_index]; - tsk_bug_assert(next_edge_list_node < self->num_edges); - w = &self->edge_list_nodes[next_edge_list_node]; - next_edge_list_node++; - w->edge.id = k; - w->edge.left = edges->left[k]; - w->edge.right = edges->right[k]; - w->edge.parent = edges->parent[k]; - w->edge.child = edges->child[k]; - w->edge.metadata = edges->metadata + edges->metadata_offset[k]; - 
w->edge.metadata_length - = edges->metadata_offset[k + 1] - edges->metadata_offset[k]; - w->next = NULL; - w->prev = NULL; - if (in_head == NULL) { - in_head = w; - in_tail = w; - } else { - in_tail->next = w; - w->prev = in_tail; - in_tail = w; - } - self->insertion_index++; - } - edges_in.head = in_head; - edges_in.tail = in_tail; - - right = sequence_length; - if (self->insertion_index < (tsk_id_t) self->num_edges) { - right = TSK_MIN(right, edges->left[insertion_order[self->insertion_index]]); - } - if (self->removal_index < (tsk_id_t) self->num_edges) { - right = TSK_MIN(right, edges->right[removal_order[self->removal_index]]); - } - self->tree_index++; - ret = TSK_TREE_OK; - } - *edges_out_ret = edges_out; - *edges_in_ret = edges_in; - *ret_left = left; - *ret_right = right; - /* Set the left coordinate for the next tree */ - self->tree_left = right; - return ret; + return tsk_diff_iter_init( + self, tree_sequence->tables, (tsk_id_t) tree_sequence->num_trees, options); } /* ======================================================== * @@ -5927,7 +5784,7 @@ tsk_treeseq_kc_distance(const tsk_treeseq_t *self, const tsk_treeseq_t *other, if (ret != 0) { goto out; } - ret = tsk_diff_iter_init(&diff_iters[i], treeseqs[i], false); + ret = tsk_diff_iter_init_from_ts(&diff_iters[i], treeseqs[i], false); if (ret != 0) { goto out; } diff --git a/c/tskit/trees.h b/c/tskit/trees.h index cae952dee3..4a84bf3446 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * Copyright (c) 2015-2018 University of Oxford * * Permission is hereby granted, free of charge, to any person obtaining a copy @@ -59,9 +59,6 @@ extern "C" { #define TSK_DIR_FORWARD 1 #define TSK_DIR_REVERSE -1 -/* For the edge diff iterator */ -#define TSK_INCLUDE_TERMINAL (1 << 0) - /** @defgroup API_FLAGS_TS_INIT_GROUP :c:func:`tsk_treeseq_init` specific flags. 
@{ @@ -261,30 +258,6 @@ typedef struct { tsk_id_t right_index; } tsk_tree_t; -/* Diff iterator. */ -typedef struct _tsk_edge_list_node_t { - tsk_edge_t edge; - struct _tsk_edge_list_node_t *next; - struct _tsk_edge_list_node_t *prev; -} tsk_edge_list_node_t; - -typedef struct { - tsk_edge_list_node_t *head; - tsk_edge_list_node_t *tail; -} tsk_edge_list_t; - -typedef struct { - tsk_size_t num_nodes; - tsk_size_t num_edges; - double tree_left; - const tsk_treeseq_t *tree_sequence; - tsk_id_t insertion_index; - tsk_id_t removal_index; - tsk_id_t tree_index; - tsk_id_t last_index; - tsk_edge_list_node_t *edge_list_nodes; -} tsk_diff_iter_t; - /****************************************************************************/ /* Tree sequence.*/ /****************************************************************************/ @@ -1114,10 +1087,6 @@ int tsk_tree_copy(const tsk_tree_t *self, tsk_tree_t *dest, tsk_flags_t options) @{ */ -/** @brief Value returned by seeking methods when they have successfully - seeked to a non-null tree. */ -#define TSK_TREE_OK 1 - /** @brief Seek to the first tree in the sequence. 
@@ -1742,16 +1711,8 @@ bool tsk_tree_is_sample(const tsk_tree_t *self, tsk_id_t u); */ bool tsk_tree_equals(const tsk_tree_t *self, const tsk_tree_t *other); -/****************************************************************************/ -/* Diff iterator */ -/****************************************************************************/ - -int tsk_diff_iter_init( +int tsk_diff_iter_init_from_ts( tsk_diff_iter_t *self, const tsk_treeseq_t *tree_sequence, tsk_flags_t options); -int tsk_diff_iter_free(tsk_diff_iter_t *self); -int tsk_diff_iter_next(tsk_diff_iter_t *self, double *left, double *right, - tsk_edge_list_t *edges_out, tsk_edge_list_t *edges_in); -void tsk_diff_iter_print_state(const tsk_diff_iter_t *self, FILE *out); #ifdef __cplusplus } From 0efc3f1cbeab3013310783c8d85d22d1480e434e Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Thu, 23 Feb 2023 00:24:02 +0000 Subject: [PATCH 32/84] Clear docs CI cache --- .github/workflows/docs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index c3c3523e13..27d11ef5fe 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -40,7 +40,7 @@ jobs: id: venv-cache with: path: venv - key: docs-venv-v1-${{ hashFiles(env.REQUIREMENTS) }} + key: docs-venv-v2-${{ hashFiles(env.REQUIREMENTS) }} - name: Create venv and install deps (one by one to avoid conflict errors) if: steps.venv-cache.outputs.cache-hit != 'true' From 532187d3f8e72882a54a7b53630bf26531edf27c Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 3 Feb 2023 13:21:29 +0000 Subject: [PATCH 33/84] Support bool as sphinx :c:type: --- docs/_config.yml | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/docs/_config.yml b/docs/_config.yml index 600b1705a3..6906a678cd 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -71,11 +71,13 @@ sphinx: ["c:identifier", "uint32_t"], ["c:identifier", "uint64_t"], ["c:identifier", "FILE"], + ["c:identifier", 
"bool"], # This is for the anonymous interval struct embedded in the tsk_tree_t. ["c:identifier", "tsk_tree_t.@1"], ["c:type", "int32_t"], ["c:type", "uint32_t"], ["c:type", "uint64_t"], + ["c:type", "bool"], # TODO these have been triaged here to make the docs compile, but we should # sort them out properly. https://github.com/tskit-dev/tskit/issues/336 ["py:class", "array_like"], @@ -89,6 +91,18 @@ sphinx: ["py:class", "dtype=np.int64"], ] + # Added to allow "bool" be used as a :ctype: - this list has to be + # manually specifed in order to remove "bool" from it. + c_extra_keywords: [ + "alignas", + "alignof", + "complex", + "imaginary", + "noreturn", + "static_assert", + "thread_local" + ] + autodoc_member_order: bysource # Without this option, autodoc tries to put links for all return types From d13680deeb3370420db42312f54d3acbd66a1fc4 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Thu, 15 Dec 2022 15:04:19 +0000 Subject: [PATCH 34/84] Efficient in-place table subsetting Closes #2666 --- c/CHANGELOG.rst | 5 +- c/tests/test_tables.c | 1098 +++++++++++++++++++++++++++++++++++++++++ c/tskit/core.c | 6 +- c/tskit/core.h | 8 +- c/tskit/tables.c | 434 ++++++++++++++++ c/tskit/tables.h | 266 ++++++++++ 6 files changed, 1814 insertions(+), 3 deletions(-) diff --git a/c/CHANGELOG.rst b/c/CHANGELOG.rst index f314a6b3d3..c1643e1ff1 100644 --- a/c/CHANGELOG.rst +++ b/c/CHANGELOG.rst @@ -1,5 +1,5 @@ -------------------- -[1.1.2] - 2022-XX-XX +[1.1.2] - 2023-XX-XX -------------------- **Features** @@ -21,6 +21,9 @@ - Guarantee that unfiltered tables are not written to unnecessarily during simplify (:user:`jeromekelleher` :pr:`2619`). +- Add `x_table_keep_rows` methods to provide efficient in-place table subsetting + (:user:`jeromekelleher`, :pr:`2700`). 
+ -------------------- [1.1.1] - 2022-07-29 -------------------- diff --git a/c/tests/test_tables.c b/c/tests/test_tables.c index 3dd3697965..65ff5916ad 100644 --- a/c/tests/test_tables.c +++ b/c/tests/test_tables.c @@ -1491,6 +1491,106 @@ test_node_table_update_row(void) tsk_node_table_free(&table); } +static void +test_node_table_keep_rows(void) +{ + int ret; + tsk_id_t ret_id; + tsk_size_t j; + tsk_node_table_t source, t1, t2; + tsk_node_t row; + bool keep[3] = { 1, 1, 1 }; + tsk_id_t id_map[3]; + const char *metadata = "ABC"; + tsk_id_t indexes[] = { 0, 1, 2 }; + + ret = tsk_node_table_init(&source, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret_id = tsk_node_table_add_row(&source, 0, 1.0, 2, 3, metadata, 1); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_node_table_add_row(&source, 1, 2.0, 3, 4, metadata, 2); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_node_table_add_row(&source, 2, 3.0, 4, 5, metadata, 3); + CU_ASSERT_FATAL(ret_id >= 0); + + ret = tsk_node_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_node_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_node_table_equals(&t1, &source, 0)); + + ret = tsk_node_table_keep_rows(&t1, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_node_table_equals(&t1, &source, 0)); + CU_ASSERT_EQUAL_FATAL(id_map[0], 0); + CU_ASSERT_EQUAL_FATAL(id_map[1], 1); + CU_ASSERT_EQUAL_FATAL(id_map[2], 2); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + ret = tsk_node_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 0); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], -1); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_node_table_copy(&source, &t1, TSK_NO_INIT); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[0] = 0; + keep[1] = 1; + keep[2] = 0; + ret = tsk_node_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + 
CU_ASSERT_EQUAL_FATAL(t1.num_rows, 1); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], 0); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_node_table_get_row(&t1, 0, &row); + CU_ASSERT_EQUAL_FATAL(row.flags, 1); + CU_ASSERT_EQUAL_FATAL(row.time, 2.0); + CU_ASSERT_EQUAL_FATAL(row.population, 3); + CU_ASSERT_EQUAL_FATAL(row.individual, 4); + CU_ASSERT_EQUAL_FATAL(row.metadata_length, 2); + CU_ASSERT_EQUAL_FATAL(row.metadata[0], 'A'); + CU_ASSERT_EQUAL_FATAL(row.metadata[1], 'B'); + + tsk_node_table_free(&t1); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + /* Keeping first n rows equivalent to truncate */ + for (j = 0; j < source.num_rows; j++) { + ret = tsk_node_table_copy(&source, &t2, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_node_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_node_table_truncate(&t1, j + 1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[j] = 1; + ret = tsk_node_table_keep_rows(&t2, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_node_table_equals(&t1, &t2, 0)); + + /* Adding the remaining rows back on to the table gives the original + * table */ + ret = tsk_node_table_extend( + &t2, &source, source.num_rows - j - 1, indexes + j + 1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_node_table_equals(&source, &t2, 0)); + + tsk_node_table_free(&t1); + tsk_node_table_free(&t2); + } + + tsk_node_table_free(&source); +} + static void test_edge_table_with_options(tsk_flags_t options) { @@ -2034,6 +2134,203 @@ test_edge_table_update_row_no_metadata(void) tsk_edge_table_free(&table); } +static void +test_edge_table_keep_rows(void) +{ + int ret; + tsk_id_t ret_id; + tsk_size_t j; + tsk_edge_table_t source, t1, t2; + tsk_edge_t row; + bool keep[3] = { 1, 1, 1 }; + tsk_id_t id_map[3]; + const char *metadata = "ABC"; + tsk_id_t indexes[] = { 0, 1, 2 }; + + ret = tsk_edge_table_init(&source, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret_id = 
tsk_edge_table_add_row(&source, 0, 1.0, 2, 3, metadata, 1); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_edge_table_add_row(&source, 1, 2.0, 3, 4, metadata, 2); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_edge_table_add_row(&source, 2, 3.0, 4, 5, metadata, 3); + CU_ASSERT_FATAL(ret_id >= 0); + + ret = tsk_edge_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_edge_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_edge_table_equals(&t1, &source, 0)); + + ret = tsk_edge_table_keep_rows(&t1, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_edge_table_equals(&t1, &source, 0)); + CU_ASSERT_EQUAL_FATAL(id_map[0], 0); + CU_ASSERT_EQUAL_FATAL(id_map[1], 1); + CU_ASSERT_EQUAL_FATAL(id_map[2], 2); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + ret = tsk_edge_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 0); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], -1); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_edge_table_copy(&source, &t1, TSK_NO_INIT); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[0] = 0; + keep[1] = 1; + keep[2] = 0; + ret = tsk_edge_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 1); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], 0); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_edge_table_get_row(&t1, 0, &row); + CU_ASSERT_EQUAL_FATAL(row.left, 1); + CU_ASSERT_EQUAL_FATAL(row.right, 2.0); + CU_ASSERT_EQUAL_FATAL(row.parent, 3); + CU_ASSERT_EQUAL_FATAL(row.child, 4); + CU_ASSERT_EQUAL_FATAL(row.metadata_length, 2); + CU_ASSERT_EQUAL_FATAL(row.metadata[0], 'A'); + CU_ASSERT_EQUAL_FATAL(row.metadata[1], 'B'); + + tsk_edge_table_free(&t1); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + /* Keeping first n rows equivalent to truncate */ + for (j = 0; j < source.num_rows; j++) { + ret 
= tsk_edge_table_copy(&source, &t2, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_edge_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_edge_table_truncate(&t1, j + 1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[j] = 1; + ret = tsk_edge_table_keep_rows(&t2, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_edge_table_equals(&t1, &t2, 0)); + + /* Adding the remaining rows back on to the table gives the original + * table */ + ret = tsk_edge_table_extend( + &t2, &source, source.num_rows - j - 1, indexes + j + 1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_edge_table_equals(&source, &t2, 0)); + + tsk_edge_table_free(&t1); + tsk_edge_table_free(&t2); + } + + tsk_edge_table_free(&source); +} + +static void +test_edge_table_keep_rows_no_metadata(void) +{ + int ret; + tsk_id_t ret_id; + tsk_size_t j; + tsk_edge_table_t source, t1, t2; + tsk_edge_t row; + bool keep[3] = { 1, 1, 1 }; + tsk_id_t id_map[3]; + tsk_id_t indexes[] = { 0, 1, 2 }; + + ret = tsk_edge_table_init(&source, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret_id = tsk_edge_table_add_row(&source, 0, 1.0, 2, 3, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_edge_table_add_row(&source, 1, 2.0, 3, 4, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_edge_table_add_row(&source, 2, 3.0, 4, 5, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + + ret = tsk_edge_table_copy(&source, &t1, TSK_TABLE_NO_METADATA); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_edge_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_edge_table_equals(&t1, &source, 0)); + + ret = tsk_edge_table_keep_rows(&t1, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_edge_table_equals(&t1, &source, 0)); + CU_ASSERT_EQUAL_FATAL(id_map[0], 0); + CU_ASSERT_EQUAL_FATAL(id_map[1], 1); + CU_ASSERT_EQUAL_FATAL(id_map[2], 2); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + ret = tsk_edge_table_keep_rows(&t1, keep, 0, id_map); + 
CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 0); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], -1); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_edge_table_copy(&source, &t1, TSK_NO_INIT); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[0] = 0; + keep[1] = 1; + keep[2] = 0; + ret = tsk_edge_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 1); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], 0); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_edge_table_get_row(&t1, 0, &row); + CU_ASSERT_EQUAL_FATAL(row.left, 1); + CU_ASSERT_EQUAL_FATAL(row.right, 2.0); + CU_ASSERT_EQUAL_FATAL(row.parent, 3); + CU_ASSERT_EQUAL_FATAL(row.child, 4); + CU_ASSERT_EQUAL_FATAL(row.metadata_length, 0); + + tsk_edge_table_free(&t1); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + /* Keeping first n rows equivalent to truncate */ + for (j = 0; j < source.num_rows; j++) { + ret = tsk_edge_table_copy(&source, &t2, TSK_TABLE_NO_METADATA); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_edge_table_copy(&source, &t1, TSK_TABLE_NO_METADATA); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_edge_table_truncate(&t1, j + 1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[j] = 1; + ret = tsk_edge_table_keep_rows(&t2, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_edge_table_equals(&t1, &t2, 0)); + + /* Adding the remaining rows back on to the table gives the original + * table */ + ret = tsk_edge_table_extend( + &t2, &source, source.num_rows - j - 1, indexes + j + 1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_edge_table_equals(&source, &t2, 0)); + + tsk_edge_table_free(&t1); + tsk_edge_table_free(&t2); + } + + tsk_edge_table_free(&source); +} + static void test_edge_table_takeset_with_options(tsk_flags_t table_options) { @@ -2958,6 +3255,107 @@ test_site_table_update_row(void) tsk_site_table_free(&table); } +static void 
+test_site_table_keep_rows(void) +{ + int ret; + tsk_id_t ret_id; + tsk_size_t j; + tsk_site_table_t source, t1, t2; + tsk_site_t row; + const char *ancestral_state = "XYZ"; + const char *metadata = "ABC"; + bool keep[3] = { 1, 1, 1 }; + tsk_id_t id_map[3]; + tsk_id_t indexes[] = { 0, 1, 2 }; + + ret = tsk_site_table_init(&source, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret_id = tsk_site_table_add_row(&source, 0, ancestral_state, 1, metadata, 1); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_site_table_add_row(&source, 1, ancestral_state, 2, metadata, 2); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_site_table_add_row(&source, 2, ancestral_state, 3, metadata, 3); + CU_ASSERT_FATAL(ret_id >= 0); + + ret = tsk_site_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_site_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_site_table_equals(&t1, &source, 0)); + + ret = tsk_site_table_keep_rows(&t1, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_site_table_equals(&t1, &source, 0)); + CU_ASSERT_EQUAL_FATAL(id_map[0], 0); + CU_ASSERT_EQUAL_FATAL(id_map[1], 1); + CU_ASSERT_EQUAL_FATAL(id_map[2], 2); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + ret = tsk_site_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 0); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], -1); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_site_table_copy(&source, &t1, TSK_NO_INIT); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[0] = 0; + keep[1] = 1; + keep[2] = 0; + ret = tsk_site_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 1); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], 0); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_site_table_get_row(&t1, 0, &row); + CU_ASSERT_EQUAL_FATAL(row.position, 1); + 
CU_ASSERT_EQUAL_FATAL(row.ancestral_state_length, 2); + CU_ASSERT_EQUAL_FATAL(row.ancestral_state[0], 'X'); + CU_ASSERT_EQUAL_FATAL(row.ancestral_state[1], 'Y'); + CU_ASSERT_EQUAL_FATAL(row.metadata_length, 2); + CU_ASSERT_EQUAL_FATAL(row.metadata[0], 'A'); + CU_ASSERT_EQUAL_FATAL(row.metadata[1], 'B'); + + tsk_site_table_free(&t1); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + /* Keeping first n rows equivalent to truncate */ + for (j = 0; j < source.num_rows; j++) { + ret = tsk_site_table_copy(&source, &t2, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_site_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_site_table_truncate(&t1, j + 1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[j] = 1; + ret = tsk_site_table_keep_rows(&t2, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_site_table_equals(&t1, &t2, 0)); + + /* Adding the remaining rows back on to the table gives the original + * table */ + ret = tsk_site_table_extend( + &t2, &source, source.num_rows - j - 1, indexes + j + 1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_site_table_equals(&source, &t2, 0)); + + tsk_site_table_free(&t1); + tsk_site_table_free(&t2); + } + + tsk_site_table_free(&source); +} + static void test_mutation_table(void) { @@ -3643,6 +4041,199 @@ test_mutation_table_update_row(void) tsk_mutation_table_free(&table); } +static void +test_mutation_table_keep_rows(void) +{ + int ret; + tsk_id_t ret_id; + tsk_size_t j; + tsk_mutation_table_t source, t1, t2; + tsk_mutation_t row; + const char *derived_state = "XYZ"; + const char *metadata = "ABC"; + bool keep[3] = { 1, 1, 1 }; + tsk_id_t id_map[3]; + tsk_id_t indexes[] = { 0, 1, 2 }; + + ret = tsk_mutation_table_init(&source, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret_id = tsk_mutation_table_add_row( + &source, 0, 1, -1, 3.0, derived_state, 1, metadata, 1); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_mutation_table_add_row( + &source, 1, 2, -1, 4.0, derived_state, 2, metadata, 2); 
+ CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_mutation_table_add_row( + &source, 2, 3, 0, 5.0, derived_state, 3, metadata, 3); + CU_ASSERT_FATAL(ret_id >= 0); + + ret = tsk_mutation_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_mutation_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_mutation_table_equals(&t1, &source, 0)); + + ret = tsk_mutation_table_keep_rows(&t1, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_mutation_table_equals(&t1, &source, 0)); + CU_ASSERT_EQUAL_FATAL(id_map[0], 0); + CU_ASSERT_EQUAL_FATAL(id_map[1], 1); + CU_ASSERT_EQUAL_FATAL(id_map[2], 2); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + ret = tsk_mutation_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 0); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], -1); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_mutation_table_copy(&source, &t1, TSK_NO_INIT); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[0] = 0; + keep[1] = 1; + keep[2] = 0; + ret = tsk_mutation_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 1); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], 0); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_mutation_table_get_row(&t1, 0, &row); + CU_ASSERT_EQUAL_FATAL(row.site, 1); + CU_ASSERT_EQUAL_FATAL(row.node, 2); + CU_ASSERT_EQUAL_FATAL(row.parent, -1); + CU_ASSERT_EQUAL_FATAL(row.time, 4); + CU_ASSERT_EQUAL_FATAL(row.derived_state_length, 2); + CU_ASSERT_EQUAL_FATAL(row.derived_state[0], 'X'); + CU_ASSERT_EQUAL_FATAL(row.derived_state[1], 'Y'); + CU_ASSERT_EQUAL_FATAL(row.metadata_length, 2); + CU_ASSERT_EQUAL_FATAL(row.metadata[0], 'A'); + CU_ASSERT_EQUAL_FATAL(row.metadata[1], 'B'); + + tsk_mutation_table_free(&t1); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + /* Keeping first n rows equivalent to 
truncate */ + for (j = 0; j < source.num_rows; j++) { + ret = tsk_mutation_table_copy(&source, &t2, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_mutation_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_mutation_table_truncate(&t1, j + 1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[j] = 1; + ret = tsk_mutation_table_keep_rows(&t2, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_mutation_table_equals(&t1, &t2, 0)); + + /* Adding the remaining rows back on to the table gives the original + * table */ + ret = tsk_mutation_table_extend( + &t2, &source, source.num_rows - j - 1, indexes + j + 1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_mutation_table_equals(&source, &t2, 0)); + + tsk_mutation_table_free(&t1); + tsk_mutation_table_free(&t2); + } + + tsk_mutation_table_free(&source); +} + +static void +test_mutation_table_keep_rows_parent_references(void) +{ + int ret; + tsk_id_t ret_id; + tsk_mutation_table_t source, t; + bool keep[4] = { 1, 1, 1, 1 }; + tsk_id_t id_map[4]; + + ret = tsk_mutation_table_init(&source, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret_id = tsk_mutation_table_add_row(&source, 0, 1, -1, 3.0, "A", 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_mutation_table_add_row(&source, 1, 2, -1, 4.0, "A", 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_mutation_table_add_row(&source, 2, 3, 1, 5.0, "A", 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_mutation_table_add_row(&source, 3, 4, 1, 6.0, "A", 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + + ret = tsk_mutation_table_copy(&source, &t, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + /* OOB errors */ + t.parent[0] = -2; + ret = tsk_mutation_table_keep_rows(&t, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MUTATION_OUT_OF_BOUNDS); + CU_ASSERT_EQUAL_FATAL(t.num_rows, 4); + + t.parent[0] = 4; + ret = tsk_mutation_table_keep_rows(&t, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 
TSK_ERR_MUTATION_OUT_OF_BOUNDS); + CU_ASSERT_EQUAL_FATAL(t.num_rows, 4); + /* But ignored if row is not kept */ + keep[0] = false; + ret = tsk_mutation_table_keep_rows(&t, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_mutation_table_free(&t); + + ret = tsk_mutation_table_copy(&source, &t, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + /* Try to remove referenced row 1 */ + keep[0] = true; + keep[1] = false; + ret = tsk_mutation_table_keep_rows(&t, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_KEEP_ROWS_MAP_TO_DELETED); + CU_ASSERT_TRUE(tsk_mutation_table_equals(&source, &t, 0)); + tsk_mutation_table_free(&t); + + ret = tsk_mutation_table_copy(&source, &t, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + /* remove unreferenced row 0 */ + keep[0] = false; + keep[1] = true; + ret = tsk_mutation_table_keep_rows(&t, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t.num_rows, 3); + CU_ASSERT_EQUAL_FATAL(t.parent[0], TSK_NULL); + CU_ASSERT_EQUAL_FATAL(t.parent[1], 0); + CU_ASSERT_EQUAL_FATAL(t.parent[2], 0); + tsk_mutation_table_free(&t); + + /* Check that we don't change the table in error cases. */ + source.parent[3] = -2; + ret = tsk_mutation_table_copy(&source, &t, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[0] = true; + ret = tsk_mutation_table_keep_rows(&t, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MUTATION_OUT_OF_BOUNDS); + CU_ASSERT_TRUE(tsk_mutation_table_equals(&source, &t, 0)); + tsk_mutation_table_free(&t); + + /* Check that we don't change the table in error cases. 
*/ + source.parent[3] = 0; + ret = tsk_mutation_table_copy(&source, &t, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[0] = false; + ret = tsk_mutation_table_keep_rows(&t, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_KEEP_ROWS_MAP_TO_DELETED); + CU_ASSERT_TRUE(tsk_mutation_table_equals(&source, &t, 0)); + tsk_mutation_table_free(&t); + + tsk_mutation_table_free(&source); +} + static void test_migration_table(void) { @@ -4244,6 +4835,108 @@ test_migration_table_update_row(void) tsk_migration_table_free(&table); } +static void +test_migration_table_keep_rows(void) +{ + int ret; + tsk_id_t ret_id; + tsk_size_t j; + tsk_migration_table_t source, t1, t2; + tsk_migration_t row; + const char *metadata = "ABC"; + bool keep[3] = { 1, 1, 1 }; + tsk_id_t id_map[3]; + tsk_id_t indexes[] = { 0, 1, 2 }; + + ret = tsk_migration_table_init(&source, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret_id = tsk_migration_table_add_row(&source, 0, 1.0, 2, 3, 4, 5, metadata, 1); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_migration_table_add_row(&source, 1, 2.0, 3, 4, 5, 6, metadata, 2); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_migration_table_add_row(&source, 2, 3.0, 4, 5, 6, 7, metadata, 3); + CU_ASSERT_FATAL(ret_id >= 0); + + ret = tsk_migration_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_migration_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_migration_table_equals(&t1, &source, 0)); + + ret = tsk_migration_table_keep_rows(&t1, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_migration_table_equals(&t1, &source, 0)); + CU_ASSERT_EQUAL_FATAL(id_map[0], 0); + CU_ASSERT_EQUAL_FATAL(id_map[1], 1); + CU_ASSERT_EQUAL_FATAL(id_map[2], 2); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + ret = tsk_migration_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 0); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + 
CU_ASSERT_EQUAL_FATAL(id_map[1], -1); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_migration_table_copy(&source, &t1, TSK_NO_INIT); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[0] = 0; + keep[1] = 1; + keep[2] = 0; + ret = tsk_migration_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 1); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], 0); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_migration_table_get_row(&t1, 0, &row); + CU_ASSERT_EQUAL_FATAL(row.left, 1); + CU_ASSERT_EQUAL_FATAL(row.right, 2); + CU_ASSERT_EQUAL_FATAL(row.node, 3); + CU_ASSERT_EQUAL_FATAL(row.source, 4); + CU_ASSERT_EQUAL_FATAL(row.dest, 5); + CU_ASSERT_EQUAL_FATAL(row.time, 6); + CU_ASSERT_EQUAL_FATAL(row.metadata_length, 2); + CU_ASSERT_EQUAL_FATAL(row.metadata[0], 'A'); + CU_ASSERT_EQUAL_FATAL(row.metadata[1], 'B'); + + tsk_migration_table_free(&t1); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + /* Keeping first n rows equivalent to truncate */ + for (j = 0; j < source.num_rows; j++) { + ret = tsk_migration_table_copy(&source, &t2, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_migration_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_migration_table_truncate(&t1, j + 1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[j] = 1; + ret = tsk_migration_table_keep_rows(&t2, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_migration_table_equals(&t1, &t2, 0)); + + /* Adding the remaining rows back on to the table gives the original + * table */ + ret = tsk_migration_table_extend( + &t2, &source, source.num_rows - j - 1, indexes + j + 1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_migration_table_equals(&source, &t2, 0)); + + tsk_migration_table_free(&t1); + tsk_migration_table_free(&t2); + } + + tsk_migration_table_free(&source); +} + static void test_individual_table(void) { @@ -4969,6 +5662,201 @@ test_individual_table_update_row(void) 
tsk_individual_table_free(&table); } +static void +test_individual_table_keep_rows(void) +{ + int ret; + tsk_id_t ret_id; + tsk_individual_t row; + double location[] = { 0, 1, 2 }; + tsk_id_t parents[] = { -1, 1, -1 }; + const char *metadata = "ABC"; + bool keep[3] = { 1, 1, 1 }; + tsk_id_t indexes[] = { 0, 1, 2 }; + tsk_id_t id_map[3]; + tsk_individual_table_t source, t1, t2; + tsk_size_t j; + + ret = tsk_individual_table_init(&source, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret_id + = tsk_individual_table_add_row(&source, 0, location, 1, parents, 1, metadata, 1); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id + = tsk_individual_table_add_row(&source, 1, location, 2, parents, 2, metadata, 2); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id + = tsk_individual_table_add_row(&source, 2, location, 3, parents, 3, metadata, 3); + CU_ASSERT_FATAL(ret_id >= 0); + + ret = tsk_individual_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_individual_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_individual_table_equals(&t1, &source, 0)); + + ret = tsk_individual_table_keep_rows(&t1, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_individual_table_equals(&t1, &source, 0)); + CU_ASSERT_EQUAL_FATAL(id_map[0], 0); + CU_ASSERT_EQUAL_FATAL(id_map[1], 1); + CU_ASSERT_EQUAL_FATAL(id_map[2], 2); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + ret = tsk_individual_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 0); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], -1); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_individual_table_copy(&source, &t1, TSK_NO_INIT); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[0] = 0; + keep[1] = 1; + keep[2] = 0; + ret = tsk_individual_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 1); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); 
+ CU_ASSERT_EQUAL_FATAL(id_map[1], 0); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_individual_table_get_row(&t1, 0, &row); + CU_ASSERT_EQUAL_FATAL(row.flags, 1); + CU_ASSERT_EQUAL_FATAL(row.parents_length, 2); + CU_ASSERT_EQUAL_FATAL(row.parents[0], -1); + CU_ASSERT_EQUAL_FATAL(row.parents[1], 0); + CU_ASSERT_EQUAL_FATAL(row.location_length, 2); + CU_ASSERT_EQUAL_FATAL(row.location[0], 0); + CU_ASSERT_EQUAL_FATAL(row.location[1], 1); + CU_ASSERT_EQUAL_FATAL(row.metadata_length, 2); + CU_ASSERT_EQUAL_FATAL(row.metadata[0], 'A'); + CU_ASSERT_EQUAL_FATAL(row.metadata[1], 'B'); + + tsk_individual_table_free(&t1); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + /* Keeping first n rows equivalent to truncate */ + for (j = 0; j < source.num_rows; j++) { + ret = tsk_individual_table_copy(&source, &t2, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_individual_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_individual_table_truncate(&t1, j + 1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[j] = 1; + ret = tsk_individual_table_keep_rows(&t2, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_individual_table_equals(&t1, &t2, 0)); + + /* Adding the remaining rows back on to the table gives the original + * table */ + ret = tsk_individual_table_extend( + &t2, &source, source.num_rows - j - 1, indexes + j + 1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_individual_table_equals(&source, &t2, 0)); + + tsk_individual_table_free(&t1); + tsk_individual_table_free(&t2); + } + + tsk_individual_table_free(&source); +} + +static void +test_individual_table_keep_rows_parent_references(void) +{ + int ret; + tsk_id_t ret_id; + tsk_individual_table_t source, t; + bool keep[] = { 1, 1, 1, 1 }; + tsk_id_t parents[] = { -1, 1, 2 }; + tsk_id_t id_map[4]; + + ret = tsk_individual_table_init(&source, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret_id = tsk_individual_table_add_row(&source, 0, NULL, 0, parents, 1, NULL, 0); + 
CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_individual_table_add_row(&source, 0, NULL, 0, parents, 3, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_individual_table_add_row(&source, 0, NULL, 0, parents, 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_individual_table_add_row(&source, 0, NULL, 0, parents, 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + + ret = tsk_individual_table_copy(&source, &t, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + /* OOB errors */ + t.parents[0] = -2; + ret = tsk_individual_table_keep_rows(&t, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS); + CU_ASSERT_EQUAL_FATAL(t.num_rows, 4); + + t.parents[0] = 4; + ret = tsk_individual_table_keep_rows(&t, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS); + CU_ASSERT_EQUAL_FATAL(t.num_rows, 4); + /* But ignored if row is not kept */ + keep[0] = false; + ret = tsk_individual_table_keep_rows(&t, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_individual_table_free(&t); + + ret = tsk_individual_table_copy(&source, &t, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + /* Try to remove referenced row 2 */ + keep[0] = true; + keep[2] = false; + ret = tsk_individual_table_keep_rows(&t, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_KEEP_ROWS_MAP_TO_DELETED); + CU_ASSERT_TRUE(tsk_individual_table_equals(&source, &t, 0)); + tsk_individual_table_free(&t); + + ret = tsk_individual_table_copy(&source, &t, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + /* remove unreferenced row 0 */ + keep[0] = false; + keep[2] = true; + ret = tsk_individual_table_keep_rows(&t, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t.num_rows, 3); + CU_ASSERT_EQUAL_FATAL(t.parents[0], TSK_NULL); + CU_ASSERT_EQUAL_FATAL(t.parents[1], 0); + CU_ASSERT_EQUAL_FATAL(t.parents[2], 1); + tsk_individual_table_free(&t); + + /* Check that we don't change the table in error cases. 
*/ + source.parents[1] = -2; + ret = tsk_individual_table_copy(&source, &t, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[0] = true; + ret = tsk_individual_table_keep_rows(&t, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS); + CU_ASSERT_TRUE(tsk_individual_table_equals(&source, &t, 0)); + tsk_individual_table_free(&t); + + /* Check that we don't change the table in error cases. */ + source.parents[1] = 0; + ret = tsk_individual_table_copy(&source, &t, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[0] = false; + ret = tsk_individual_table_keep_rows(&t, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_KEEP_ROWS_MAP_TO_DELETED); + CU_ASSERT_TRUE(tsk_individual_table_equals(&source, &t, 0)); + tsk_individual_table_free(&t); + + tsk_individual_table_free(&source); +} + static void test_population_table(void) { @@ -5346,6 +6234,102 @@ test_population_table_update_row(void) tsk_population_table_free(&table); } +static void +test_population_table_keep_rows(void) +{ + int ret; + tsk_id_t ret_id; + tsk_size_t j; + tsk_population_table_t source, t1, t2; + tsk_population_t row; + const char *metadata = "ABC"; + bool keep[3] = { 1, 1, 1 }; + tsk_id_t id_map[3]; + tsk_id_t indexes[] = { 0, 1, 2 }; + + ret = tsk_population_table_init(&source, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret_id = tsk_population_table_add_row(&source, metadata, 1); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_population_table_add_row(&source, metadata, 2); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_population_table_add_row(&source, metadata, 3); + CU_ASSERT_FATAL(ret_id >= 0); + + ret = tsk_population_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_population_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_population_table_equals(&t1, &source, 0)); + + ret = tsk_population_table_keep_rows(&t1, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_population_table_equals(&t1, 
&source, 0)); + CU_ASSERT_EQUAL_FATAL(id_map[0], 0); + CU_ASSERT_EQUAL_FATAL(id_map[1], 1); + CU_ASSERT_EQUAL_FATAL(id_map[2], 2); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + ret = tsk_population_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 0); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], -1); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_population_table_copy(&source, &t1, TSK_NO_INIT); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[0] = 0; + keep[1] = 1; + keep[2] = 0; + ret = tsk_population_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 1); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], 0); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_population_table_get_row(&t1, 0, &row); + CU_ASSERT_EQUAL_FATAL(row.metadata_length, 2); + CU_ASSERT_EQUAL_FATAL(row.metadata[0], 'A'); + CU_ASSERT_EQUAL_FATAL(row.metadata[1], 'B'); + + tsk_population_table_free(&t1); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + /* Keeping first n rows equivalent to truncate */ + for (j = 0; j < source.num_rows; j++) { + ret = tsk_population_table_copy(&source, &t2, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_population_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_population_table_truncate(&t1, j + 1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[j] = 1; + ret = tsk_population_table_keep_rows(&t2, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_population_table_equals(&t1, &t2, 0)); + + /* Adding the remaining rows back on to the table gives the original + * table */ + ret = tsk_population_table_extend( + &t2, &source, source.num_rows - j - 1, indexes + j + 1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_population_table_equals(&source, &t2, 0)); + + tsk_population_table_free(&t1); + tsk_population_table_free(&t2); + } + + 
tsk_population_table_free(&source); +} + static void test_provenance_table(void) { @@ -5785,6 +6769,106 @@ test_provenance_table_update_row(void) tsk_provenance_table_free(&table); } +static void +test_provenance_table_keep_rows(void) +{ + int ret; + tsk_id_t ret_id; + tsk_size_t j; + tsk_provenance_table_t source, t1, t2; + tsk_provenance_t row; + const char *timestamp = "XYZ"; + const char *record = "ABC"; + bool keep[3] = { 1, 1, 1 }; + tsk_id_t indexes[] = { 0, 1, 2 }; + tsk_id_t id_map[3]; + + ret = tsk_provenance_table_init(&source, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret_id = tsk_provenance_table_add_row(&source, timestamp, 1, record, 1); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_provenance_table_add_row(&source, timestamp, 2, record, 2); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_provenance_table_add_row(&source, timestamp, 3, record, 3); + CU_ASSERT_FATAL(ret_id >= 0); + + ret = tsk_provenance_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_provenance_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_provenance_table_equals(&t1, &source, 0)); + + ret = tsk_provenance_table_keep_rows(&t1, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_provenance_table_equals(&t1, &source, 0)); + CU_ASSERT_EQUAL_FATAL(id_map[0], 0); + CU_ASSERT_EQUAL_FATAL(id_map[1], 1); + CU_ASSERT_EQUAL_FATAL(id_map[2], 2); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + ret = tsk_provenance_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 0); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], -1); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_provenance_table_copy(&source, &t1, TSK_NO_INIT); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[0] = 0; + keep[1] = 1; + keep[2] = 0; + ret = tsk_provenance_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + 
CU_ASSERT_EQUAL_FATAL(t1.num_rows, 1); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], 0); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_provenance_table_get_row(&t1, 0, &row); + CU_ASSERT_EQUAL_FATAL(row.timestamp_length, 2); + CU_ASSERT_EQUAL_FATAL(row.timestamp[0], 'X'); + CU_ASSERT_EQUAL_FATAL(row.timestamp[1], 'Y'); + CU_ASSERT_EQUAL_FATAL(row.record_length, 2); + CU_ASSERT_EQUAL_FATAL(row.record[0], 'A'); + CU_ASSERT_EQUAL_FATAL(row.record[1], 'B'); + + tsk_provenance_table_free(&t1); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + /* Keeping first n rows equivalent to truncate */ + for (j = 0; j < source.num_rows; j++) { + ret = tsk_provenance_table_copy(&source, &t2, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_provenance_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_provenance_table_truncate(&t1, j + 1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[j] = 1; + ret = tsk_provenance_table_keep_rows(&t2, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_provenance_table_equals(&t1, &t2, 0)); + + /* Adding the remaining rows back on to the table gives the original + * table */ + ret = tsk_provenance_table_extend( + &t2, &source, source.num_rows - j - 1, indexes + j + 1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_provenance_table_equals(&source, &t2, 0)); + + tsk_provenance_table_free(&t1); + tsk_provenance_table_free(&t2); + } + + tsk_provenance_table_free(&source); +} + static void test_table_size_increments(void) { @@ -10456,11 +11540,15 @@ main(int argc, char **argv) CU_TestInfo tests[] = { { "test_node_table", test_node_table }, { "test_node_table_update_row", test_node_table_update_row }, + { "test_node_table_keep_rows", test_node_table_keep_rows }, { "test_node_table_takeset", test_node_table_takeset }, { "test_edge_table", test_edge_table }, { "test_edge_table_update_row", test_edge_table_update_row }, { "test_edge_table_update_row_no_metadata", 
test_edge_table_update_row_no_metadata }, + { "test_edge_table_keep_rows", test_edge_table_keep_rows }, + { "test_edge_table_keep_rows_no_metadata", + test_edge_table_keep_rows_no_metadata }, { "test_edge_table_takeset", test_edge_table_takeset }, { "test_edge_table_copy_semantics", test_edge_table_copy_semantics }, { "test_edge_table_squash", test_edge_table_squash }, @@ -10472,21 +11560,31 @@ main(int argc, char **argv) { "test_edge_table_squash_metadata", test_edge_table_squash_metadata }, { "test_site_table", test_site_table }, { "test_site_table_update_row", test_site_table_update_row }, + { "test_site_table_keep_rows", test_site_table_keep_rows }, { "test_site_table_takeset", test_site_table_takeset }, { "test_mutation_table", test_mutation_table }, { "test_mutation_table_update_row", test_mutation_table_update_row }, { "test_mutation_table_takeset", test_mutation_table_takeset }, + { "test_mutation_table_keep_rows", test_mutation_table_keep_rows }, + { "test_mutation_table_keep_rows_parent_references", + test_mutation_table_keep_rows_parent_references }, { "test_migration_table", test_migration_table }, { "test_migration_table_update_row", test_migration_table_update_row }, + { "test_migration_table_keep_rows", test_migration_table_keep_rows }, { "test_migration_table_takeset", test_migration_table_takeset }, { "test_individual_table", test_individual_table }, { "test_individual_table_takeset", test_individual_table_takeset }, { "test_individual_table_update_row", test_individual_table_update_row }, + { "test_individual_table_keep_rows", test_individual_table_keep_rows }, + { "test_individual_table_keep_rows_parent_references", + test_individual_table_keep_rows_parent_references }, { "test_population_table", test_population_table }, { "test_population_table_update_row", test_population_table_update_row }, + { "test_population_table_keep_rows", test_population_table_keep_rows }, { "test_population_table_takeset", test_population_table_takeset }, { 
"test_provenance_table", test_provenance_table }, { "test_provenance_table_update_row", test_provenance_table_update_row }, + { "test_provenance_table_keep_rows", test_provenance_table_keep_rows }, { "test_provenance_table_takeset", test_provenance_table_takeset }, { "test_table_size_increments", test_table_size_increments }, { "test_table_expansion", test_table_expansion }, diff --git a/c/tskit/core.c b/c/tskit/core.c index bc50a21a5f..b1ea25badd 100644 --- a/c/tskit/core.c +++ b/c/tskit/core.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * Copyright (c) 2015-2018 University of Oxford * * Permission is hereby granted, free of charge, to any person obtaining a copy @@ -222,6 +222,10 @@ tsk_strerror_internal(int err) case TSK_ERR_SEEK_OUT_OF_BOUNDS: ret = "Tree seek position out of bounds. (TSK_ERR_SEEK_OUT_OF_BOUNDS)"; break; + case TSK_ERR_KEEP_ROWS_MAP_TO_DELETED: + ret = "One of the kept rows in the table refers to a deleted row. " + "(TSK_ERR_KEEP_ROWS_MAP_TO_DELETED)"; + break; /* Edge errors */ case TSK_ERR_NULL_PARENT: diff --git a/c/tskit/core.h b/c/tskit/core.h index 0e7d528b0c..7810a8d048 100644 --- a/c/tskit/core.h +++ b/c/tskit/core.h @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * Copyright (c) 2015-2018 University of Oxford * * Permission is hereby granted, free of charge, to any person obtaining a copy @@ -356,6 +356,12 @@ A time value was non-finite (NaN counts as finite) A genomic position was non-finite */ #define TSK_ERR_GENOME_COORDS_NONFINITE -211 +/** +One of the rows in the retained table refers to a row that has been +deleted. 
+*/ +#define TSK_ERR_KEEP_ROWS_MAP_TO_DELETED -212 + /** @} */ /** diff --git a/c/tskit/tables.c b/c/tskit/tables.c index 1e910bca69..c0ce82bd39 100644 --- a/c/tskit/tables.c +++ b/c/tskit/tables.c @@ -732,6 +732,187 @@ write_metadata_schema_header( return fprintf(out, fmt, (int) metadata_schema_length, metadata_schema); } +/* Utilities for in-place subsetting columns */ + +static tsk_size_t +count_true(tsk_size_t num_rows, const bool *restrict keep) +{ + tsk_size_t j; + tsk_size_t count = 0; + + for (j = 0; j < num_rows; j++) { + if (keep[j]) { + count++; + } + } + return count; +} + +static void +keep_mask_to_id_map( + tsk_size_t num_rows, const bool *restrict keep, tsk_id_t *restrict id_map) +{ + tsk_size_t j; + tsk_id_t next_id = 0; + + for (j = 0; j < num_rows; j++) { + id_map[j] = TSK_NULL; + if (keep[j]) { + id_map[j] = next_id; + next_id++; + } + } +} + +static tsk_size_t +subset_remap_id_column(tsk_id_t *restrict column, tsk_size_t num_rows, + const bool *restrict keep, const tsk_id_t *restrict id_map) +{ + tsk_size_t j, k; + tsk_id_t value; + + k = 0; + for (j = 0; j < num_rows; j++) { + if (keep[j]) { + value = column[j]; + if (value != TSK_NULL) { + value = id_map[value]; + } + column[k] = value; + k++; + } + } + return k; +} + +/* Trigger warning: C++ programmers should look away... This may be one of the + * few cases where some macro funkiness is warranted, as these are exact + * duplicates of the same function with just the type of the column + * parameter changed. 
*/ + +static tsk_size_t +subset_id_column( + tsk_id_t *restrict column, tsk_size_t num_rows, const bool *restrict keep) +{ + tsk_size_t j, k; + + k = 0; + for (j = 0; j < num_rows; j++) { + if (keep[j]) { + column[k] = column[j]; + k++; + } + } + return k; +} + +static tsk_size_t +subset_flags_column( + tsk_flags_t *restrict column, tsk_size_t num_rows, const bool *restrict keep) +{ + tsk_size_t j, k; + + k = 0; + for (j = 0; j < num_rows; j++) { + if (keep[j]) { + column[k] = column[j]; + k++; + } + } + return k; +} + +static tsk_size_t +subset_double_column( + double *restrict column, tsk_size_t num_rows, const bool *restrict keep) +{ + tsk_size_t j, k; + + k = 0; + for (j = 0; j < num_rows; j++) { + if (keep[j]) { + column[k] = column[j]; + k++; + } + } + return k; +} + +static tsk_size_t +subset_ragged_char_column(char *restrict data, tsk_size_t *restrict offset_col, + tsk_size_t num_rows, const bool *restrict keep) +{ + tsk_size_t j, k, i, offset; + + k = 0; + offset = 0; + for (j = 0; j < num_rows; j++) { + if (keep[j]) { + offset_col[k] = offset; + /* Note: Unclear whether it's worth calling memcpy instead here? + * Need to be careful since the regions are overlapping */ + for (i = offset_col[j]; i < offset_col[j + 1]; i++) { + data[offset] = data[i]; + offset++; + } + k++; + } + } + offset_col[k] = offset; + return offset; +} + +static tsk_size_t +subset_ragged_double_column(double *restrict data, tsk_size_t *restrict offset_col, + tsk_size_t num_rows, const bool *restrict keep) +{ + tsk_size_t j, k, i, offset; + + k = 0; + offset = 0; + for (j = 0; j < num_rows; j++) { + if (keep[j]) { + offset_col[k] = offset; + /* Note: Unclear whether it's worth calling memcpy instead here? 
+ * Need to be careful since the regions are overlapping */ + for (i = offset_col[j]; i < offset_col[j + 1]; i++) { + data[offset] = data[i]; + offset++; + } + k++; + } + } + offset_col[k] = offset; + return offset; +} + +static tsk_size_t +subset_remap_ragged_id_column(tsk_id_t *restrict data, tsk_size_t *restrict offset_col, + tsk_size_t num_rows, const bool *restrict keep, const tsk_id_t *restrict id_map) +{ + tsk_size_t j, k, i, offset; + tsk_id_t di; + + k = 0; + offset = 0; + for (j = 0; j < num_rows; j++) { + if (keep[j]) { + offset_col[k] = offset; + for (i = offset_col[j]; i < offset_col[j + 1]; i++) { + di = data[i]; + if (di != TSK_NULL) { + di = id_map[di]; + } + data[offset] = di; + offset++; + } + k++; + } + } + offset_col[k] = offset; + return offset; +} + /************************* * reference sequence *************************/ @@ -1622,6 +1803,71 @@ tsk_individual_table_equals(const tsk_individual_table_t *self, return ret; } +int +tsk_individual_table_keep_rows(tsk_individual_table_t *self, const bool *keep, + tsk_flags_t TSK_UNUSED(options), tsk_id_t *ret_id_map) +{ + int ret = 0; + const tsk_size_t current_num_rows = self->num_rows; + tsk_size_t j, k, remaining_rows; + tsk_id_t pk; + tsk_id_t *id_map = ret_id_map; + tsk_id_t *restrict parents = self->parents; + tsk_size_t *restrict parents_offset = self->parents_offset; + + if (ret_id_map == NULL) { + id_map = tsk_malloc(current_num_rows * sizeof(*id_map)); + if (id_map == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + } + + keep_mask_to_id_map(current_num_rows, keep, id_map); + + /* See notes in tsk_mutation_table_keep_rows for possibilities + * on making this more flexible */ + for (j = 0; j < current_num_rows; j++) { + if (keep[j]) { + for (k = parents_offset[j]; k < parents_offset[j + 1]; k++) { + pk = parents[k]; + if (pk != TSK_NULL) { + if (pk < 0 || pk >= (tsk_id_t) current_num_rows) { + ret = TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS; + ; + goto out; + } + if (id_map[pk] == TSK_NULL) { + 
ret = TSK_ERR_KEEP_ROWS_MAP_TO_DELETED; + goto out; + } + } + } + } + } + + remaining_rows = subset_flags_column(self->flags, current_num_rows, keep); + self->parents_length = subset_remap_ragged_id_column( + self->parents, self->parents_offset, current_num_rows, keep, id_map); + self->location_length = subset_ragged_double_column( + self->location, self->location_offset, current_num_rows, keep); + if (self->metadata_length > 0) { + /* Implementation note: we special case metadata here because + * it'll make the common-case of no metadata a bit faster, and + * to also potentially support more general use of the + * TSK_TABLE_NO_METADATA option. This is done for all the tables + * but only commented on here. */ + self->metadata_length = subset_ragged_char_column( + self->metadata, self->metadata_offset, current_num_rows, keep); + } + self->num_rows = remaining_rows; +out: + if (ret_id_map == NULL) { + tsk_safe_free(id_map); + } + return ret; +} + static int tsk_individual_table_dump( const tsk_individual_table_t *self, kastore_t *store, tsk_flags_t options) @@ -2271,6 +2517,29 @@ tsk_node_table_get_row(const tsk_node_table_t *self, tsk_id_t index, tsk_node_t return ret; } +int +tsk_node_table_keep_rows(tsk_node_table_t *self, const bool *keep, + tsk_flags_t TSK_UNUSED(options), tsk_id_t *id_map) +{ + int ret = 0; + tsk_size_t remaining_rows; + + if (id_map != NULL) { + keep_mask_to_id_map(self->num_rows, keep, id_map); + } + + remaining_rows = subset_flags_column(self->flags, self->num_rows, keep); + subset_double_column(self->time, self->num_rows, keep); + subset_id_column(self->population, self->num_rows, keep); + subset_id_column(self->individual, self->num_rows, keep); + if (self->metadata_length > 0) { + self->metadata_length = subset_ragged_char_column( + self->metadata, self->metadata_offset, self->num_rows, keep); + } + self->num_rows = remaining_rows; + return ret; +} + static int tsk_node_table_dump(const tsk_node_table_t *self, kastore_t *store, 
tsk_flags_t options) { @@ -2940,6 +3209,29 @@ tsk_edge_table_equals( return ret; } +int +tsk_edge_table_keep_rows(tsk_edge_table_t *self, const bool *keep, + tsk_flags_t TSK_UNUSED(options), tsk_id_t *id_map) +{ + int ret = 0; + tsk_size_t remaining_rows; + + if (id_map != NULL) { + keep_mask_to_id_map(self->num_rows, keep, id_map); + } + remaining_rows = subset_double_column(self->left, self->num_rows, keep); + subset_double_column(self->right, self->num_rows, keep); + subset_id_column(self->parent, self->num_rows, keep); + subset_id_column(self->child, self->num_rows, keep); + if (self->metadata_length > 0) { + tsk_bug_assert(!(self->options & TSK_TABLE_NO_METADATA)); + self->metadata_length = subset_ragged_char_column( + self->metadata, self->metadata_offset, self->num_rows, keep); + } + self->num_rows = remaining_rows; + return ret; +} + static int tsk_edge_table_dump(const tsk_edge_table_t *self, kastore_t *store, tsk_flags_t options) { @@ -3675,6 +3967,28 @@ tsk_site_table_dump_text(const tsk_site_table_t *self, FILE *out) return ret; } +int +tsk_site_table_keep_rows(tsk_site_table_t *self, const bool *keep, + tsk_flags_t TSK_UNUSED(options), tsk_id_t *id_map) +{ + int ret = 0; + tsk_size_t remaining_rows; + + if (id_map != NULL) { + keep_mask_to_id_map(self->num_rows, keep, id_map); + } + + remaining_rows = subset_double_column(self->position, self->num_rows, keep); + self->ancestral_state_length = subset_ragged_char_column( + self->ancestral_state, self->ancestral_state_offset, self->num_rows, keep); + if (self->metadata_length > 0) { + self->metadata_length = subset_ragged_char_column( + self->metadata, self->metadata_offset, self->num_rows, keep); + } + self->num_rows = remaining_rows; + return ret; +} + static int tsk_site_table_dump(const tsk_site_table_t *self, kastore_t *store, tsk_flags_t options) { @@ -4418,6 +4732,65 @@ tsk_mutation_table_dump_text(const tsk_mutation_table_t *self, FILE *out) return ret; } +int 
+tsk_mutation_table_keep_rows(tsk_mutation_table_t *self, const bool *keep, + tsk_flags_t TSK_UNUSED(options), tsk_id_t *ret_id_map) +{ + int ret = 0; + const tsk_size_t current_num_rows = self->num_rows; + tsk_size_t j, remaining_rows; + tsk_id_t pj; + tsk_id_t *id_map = ret_id_map; + tsk_id_t *restrict parent = self->parent; + + if (ret_id_map == NULL) { + id_map = tsk_malloc(current_num_rows * sizeof(*id_map)); + if (id_map == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + } + + keep_mask_to_id_map(current_num_rows, keep, id_map); + + /* Note: we could add some options to avoid these checks if we wanted. + * MAP_DELETED_TO_NULL is an obvious one, and I guess it might be + * helpful to also provide NO_REMAP to prevent reference remapping + * entirely. */ + for (j = 0; j < current_num_rows; j++) { + if (keep[j]) { + pj = parent[j]; + if (pj != TSK_NULL) { + if (pj < 0 || pj >= (tsk_id_t) current_num_rows) { + ret = TSK_ERR_MUTATION_OUT_OF_BOUNDS; + goto out; + } + if (id_map[pj] == TSK_NULL) { + ret = TSK_ERR_KEEP_ROWS_MAP_TO_DELETED; + goto out; + } + } + } + } + + remaining_rows = subset_id_column(self->site, current_num_rows, keep); + subset_id_column(self->node, current_num_rows, keep); + subset_remap_id_column(parent, current_num_rows, keep, id_map); + subset_double_column(self->time, current_num_rows, keep); + self->derived_state_length = subset_ragged_char_column( + self->derived_state, self->derived_state_offset, current_num_rows, keep); + if (self->metadata_length > 0) { + self->metadata_length = subset_ragged_char_column( + self->metadata, self->metadata_offset, current_num_rows, keep); + } + self->num_rows = remaining_rows; +out: + if (ret_id_map == NULL) { + tsk_safe_free(id_map); + } + return ret; +} + static int tsk_mutation_table_dump( const tsk_mutation_table_t *self, kastore_t *store, tsk_flags_t options) @@ -5063,6 +5436,31 @@ tsk_migration_table_equals(const tsk_migration_table_t *self, return ret; } +int 
+tsk_migration_table_keep_rows(tsk_migration_table_t *self, const bool *keep, + tsk_flags_t TSK_UNUSED(options), tsk_id_t *id_map) +{ + int ret = 0; + tsk_size_t remaining_rows; + + if (id_map != NULL) { + keep_mask_to_id_map(self->num_rows, keep, id_map); + } + + remaining_rows = subset_double_column(self->left, self->num_rows, keep); + subset_double_column(self->right, self->num_rows, keep); + subset_id_column(self->node, self->num_rows, keep); + subset_id_column(self->source, self->num_rows, keep); + subset_id_column(self->dest, self->num_rows, keep); + subset_double_column(self->time, self->num_rows, keep); + if (self->metadata_length > 0) { + self->metadata_length = subset_ragged_char_column( + self->metadata, self->metadata_offset, self->num_rows, keep); + } + self->num_rows = remaining_rows; + return ret; +} + static int tsk_migration_table_dump( const tsk_migration_table_t *self, kastore_t *store, tsk_flags_t options) @@ -5632,6 +6030,24 @@ tsk_population_table_equals(const tsk_population_table_t *self, return ret; } +int +tsk_population_table_keep_rows(tsk_population_table_t *self, const bool *keep, + tsk_flags_t TSK_UNUSED(options), tsk_id_t *id_map) +{ + int ret = 0; + + if (id_map != NULL) { + keep_mask_to_id_map(self->num_rows, keep, id_map); + } + + if (self->metadata_length > 0) { + self->metadata_length = subset_ragged_char_column( + self->metadata, self->metadata_offset, self->num_rows, keep); + } + self->num_rows = count_true(self->num_rows, keep); + return ret; +} + static int tsk_population_table_dump( const tsk_population_table_t *self, kastore_t *store, tsk_flags_t options) @@ -6244,6 +6660,24 @@ tsk_provenance_table_equals(const tsk_provenance_table_t *self, return ret; } +int +tsk_provenance_table_keep_rows(tsk_provenance_table_t *self, const bool *keep, + tsk_flags_t TSK_UNUSED(options), tsk_id_t *id_map) +{ + int ret = 0; + + if (id_map != NULL) { + keep_mask_to_id_map(self->num_rows, keep, id_map); + } + self->timestamp_length = 
subset_ragged_char_column( + self->timestamp, self->timestamp_offset, self->num_rows, keep); + self->record_length = subset_ragged_char_column( + self->record, self->record_offset, self->num_rows, keep); + self->num_rows = count_true(self->num_rows, keep); + + return ret; +} + static int tsk_provenance_table_dump( const tsk_provenance_table_t *self, kastore_t *store, tsk_flags_t options) diff --git a/c/tskit/tables.h b/c/tskit/tables.h index 321a675271..13934a28db 100644 --- a/c/tskit/tables.h +++ b/c/tskit/tables.h @@ -1077,6 +1077,49 @@ int tsk_individual_table_extend(tsk_individual_table_t *self, const tsk_individual_table_t *other, tsk_size_t num_rows, const tsk_id_t *row_indexes, tsk_flags_t options); +/** +@brief Subset this table by keeping rows according to a boolean mask. + +@rst +Deletes rows from this table and optionally return the mapping from IDs in +the current table to the updated table. Rows are kept or deleted according to +the specified boolean array ``keep`` such that for each row ``j`` if +``keep[j]`` is false (zero) the row is deleted, and otherwise the row is +retained. Thus, ``keep`` must be an array of at least ``num_rows`` +:c:type:`bool` values. + +If the ``id_map`` argument is non-null, this array will be updated to represent +the mapping between IDs before and after row deletion. For row ``j``, +``id_map[j]`` will contain the new ID for row ``j`` if it is retained, or +:c:macro:`TSK_NULL` if the row has been removed. Thus, ``id_map`` must be an +array of at least ``num_rows`` :c:type:`tsk_id_t` values. + +The values in the ``parents`` column are updated according to this map, so that +reference integrity within the table is maintained. As a consequence of this, +the values in the ``parents`` column for kept rows are bounds-checked and an +error raised if they are not valid. Rows that are deleted are not checked for +parent ID integrity. 
+ +If an attempt is made to delete rows that are referred to by the ``parents`` +column of rows that are retained, an error is raised. + +These error conditions are checked before any alterations to the table are +made. + +@endrst + +@param self A pointer to a tsk_individual_table_t object. +@param keep Array of boolean flags describing whether a particular + row should be kept or not. Must be at least ``num_rows`` long. +@param options Bitwise option flags. Currently unused; should be + set to zero to ensure compatibility with later versions of tskit. +@param id_map An array in which to store the mapping between new + and old IDs. If NULL, this will be ignored. +@return Return 0 on success or a negative value on failure. +*/ +int tsk_individual_table_keep_rows(tsk_individual_table_t *self, const bool *keep, + tsk_flags_t options, tsk_id_t *id_map); + /** @brief Returns true if the data in the specified table is identical to the data in this table. @@ -1425,6 +1468,36 @@ and is not checked for compatibility with any existing schema on this table. int tsk_node_table_extend(tsk_node_table_t *self, const tsk_node_table_t *other, tsk_size_t num_rows, const tsk_id_t *row_indexes, tsk_flags_t options); +/** +@brief Subset this table by keeping rows according to a boolean mask. + +@rst +Deletes rows from this table and optionally return the mapping from IDs in +the current table to the updated table. Rows are kept or deleted according to +the specified boolean array ``keep`` such that for each row ``j`` if +``keep[j]`` is false (zero) the row is deleted, and otherwise the row is +retained. Thus, ``keep`` must be an array of at least ``num_rows`` +:c:type:`bool` values. + +If the ``id_map`` argument is non-null, this array will be updated to represent +the mapping between IDs before and after row deletion. For row ``j``, +``id_map[j]`` will contain the new ID for row ``j`` if it is retained, or +:c:macro:`TSK_NULL` if the row has been removed. 
Thus, ``id_map`` must be an +array of at least ``num_rows`` :c:type:`tsk_id_t` values. +@endrst + +@param self A pointer to a tsk_node_table_t object. +@param keep Array of boolean flags describing whether a particular + row should be kept or not. Must be at least ``num_rows`` long. +@param options Bitwise option flags. Currently unused; should be + set to zero to ensure compatibility with later versions of tskit. +@param id_map An array in which to store the mapping between new + and old IDs. If NULL, this will be ignored. +@return Return 0 on success or a negative value on failure. +*/ +int tsk_node_table_keep_rows( + tsk_node_table_t *self, const bool *keep, tsk_flags_t options, tsk_id_t *id_map); + /** @brief Returns true if the data in the specified table is identical to the data in this table. @@ -1735,6 +1808,36 @@ as-is and is not checked for compatibility with any existing schema on this tabl int tsk_edge_table_extend(tsk_edge_table_t *self, const tsk_edge_table_t *other, tsk_size_t num_rows, const tsk_id_t *row_indexes, tsk_flags_t options); +/** +@brief Subset this table by keeping rows according to a boolean mask. + +@rst +Deletes rows from this table and optionally return the mapping from IDs in +the current table to the updated table. Rows are kept or deleted according to +the specified boolean array ``keep`` such that for each row ``j`` if +``keep[j]`` is false (zero) the row is deleted, and otherwise the row is +retained. Thus, ``keep`` must be an array of at least ``num_rows`` +:c:type:`bool` values. + +If the ``id_map`` argument is non-null, this array will be updated to represent +the mapping between IDs before and after row deletion. For row ``j``, +``id_map[j]`` will contain the new ID for row ``j`` if it is retained, or +:c:macro:`TSK_NULL` if the row has been removed. Thus, ``id_map`` must be an +array of at least ``num_rows`` :c:type:`tsk_id_t` values. +@endrst + +@param self A pointer to a tsk_edge_table_t object. 
+@param keep Array of boolean flags describing whether a particular + row should be kept or not. Must be at least ``num_rows`` long. +@param options Bitwise option flags. Currently unused; should be + set to zero to ensure compatibility with later versions of tskit. +@param id_map An array in which to store the mapping between new + and old IDs. If NULL, this will be ignored. +@return Return 0 on success or a negative value on failure. +*/ +int tsk_edge_table_keep_rows( + tsk_edge_table_t *self, const bool *keep, tsk_flags_t options, tsk_id_t *id_map); + /** @brief Returns true if the data in the specified table is identical to the data in this table. @@ -2069,6 +2172,36 @@ int tsk_migration_table_extend(tsk_migration_table_t *self, const tsk_migration_table_t *other, tsk_size_t num_rows, const tsk_id_t *row_indexes, tsk_flags_t options); +/** +@brief Subset this table by keeping rows according to a boolean mask. + +@rst +Deletes rows from this table and optionally return the mapping from IDs in +the current table to the updated table. Rows are kept or deleted according to +the specified boolean array ``keep`` such that for each row ``j`` if +``keep[j]`` is false (zero) the row is deleted, and otherwise the row is +retained. Thus, ``keep`` must be an array of at least ``num_rows`` +:c:type:`bool` values. + +If the ``id_map`` argument is non-null, this array will be updated to represent +the mapping between IDs before and after row deletion. For row ``j``, +``id_map[j]`` will contain the new ID for row ``j`` if it is retained, or +:c:macro:`TSK_NULL` if the row has been removed. Thus, ``id_map`` must be an +array of at least ``num_rows`` :c:type:`tsk_id_t` values. +@endrst + +@param self A pointer to a tsk_migration_table_t object. +@param keep Array of boolean flags describing whether a particular + row should be kept or not. Must be at least ``num_rows`` long. +@param options Bitwise option flags. 
Currently unused; should be + set to zero to ensure compatibility with later versions of tskit. +@param id_map An array in which to store the mapping between new + and old IDs. If NULL, this will be ignored. +@return Return 0 on success or a negative value on failure. +*/ +int tsk_migration_table_keep_rows(tsk_migration_table_t *self, const bool *keep, + tsk_flags_t options, tsk_id_t *id_map); + /** @brief Returns true if the data in the specified table is identical to the data in this table. @@ -2377,6 +2510,36 @@ and is not checked for compatibility with any existing schema on this table. int tsk_site_table_extend(tsk_site_table_t *self, const tsk_site_table_t *other, tsk_size_t num_rows, const tsk_id_t *row_indexes, tsk_flags_t options); +/** +@brief Subset this table by keeping rows according to a boolean mask. + +@rst +Deletes rows from this table and optionally return the mapping from IDs in +the current table to the updated table. Rows are kept or deleted according to +the specified boolean array ``keep`` such that for each row ``j`` if +``keep[j]`` is false (zero) the row is deleted, and otherwise the row is +retained. Thus, ``keep`` must be an array of at least ``num_rows`` +:c:type:`bool` values. + +If the ``id_map`` argument is non-null, this array will be updated to represent +the mapping between IDs before and after row deletion. For row ``j``, +``id_map[j]`` will contain the new ID for row ``j`` if it is retained, or +:c:macro:`TSK_NULL` if the row has been removed. Thus, ``id_map`` must be an +array of at least ``num_rows`` :c:type:`tsk_id_t` values. +@endrst + +@param self A pointer to a tsk_site_table_t object. +@param keep Array of boolean flags describing whether a particular + row should be kept or not. Must be at least ``num_rows`` long. +@param options Bitwise option flags. Currently unused; should be + set to zero to ensure compatibility with later versions of tskit. 
+@param id_map An array in which to store the mapping between new + and old IDs. If NULL, this will be ignored. +@return Return 0 on success or a negative value on failure. +*/ +int tsk_site_table_keep_rows( + tsk_site_table_t *self, const bool *keep, tsk_flags_t options, tsk_id_t *id_map); + /** @brief Returns true if the data in the specified table is identical to the data in this table. @@ -2713,6 +2876,49 @@ int tsk_mutation_table_extend(tsk_mutation_table_t *self, const tsk_mutation_table_t *other, tsk_size_t num_rows, const tsk_id_t *row_indexes, tsk_flags_t options); +/** +@brief Subset this table by keeping rows according to a boolean mask. + +@rst +Deletes rows from this table and optionally return the mapping from IDs in +the current table to the updated table. Rows are kept or deleted according to +the specified boolean array ``keep`` such that for each row ``j`` if +``keep[j]`` is false (zero) the row is deleted, and otherwise the row is +retained. Thus, ``keep`` must be an array of at least ``num_rows`` +:c:type:`bool` values. + +If the ``id_map`` argument is non-null, this array will be updated to represent +the mapping between IDs before and after row deletion. For row ``j``, +``id_map[j]`` will contain the new ID for row ``j`` if it is retained, or +:c:macro:`TSK_NULL` if the row has been removed. Thus, ``id_map`` must be an +array of at least ``num_rows`` :c:type:`tsk_id_t` values. + +The values in the ``parent`` column are updated according to this map, so that +reference integrity within the table is maintained. As a consequence of this, +the values in the ``parent`` column for kept rows are bounds-checked and an +error raised if they are not valid. Rows that are deleted are not checked for +parent ID integrity. + +If an attempt is made to delete rows that are referred to by the ``parent`` +column of rows that are retained, an error is raised. + +These error conditions are checked before any alterations to the table are +made. 
+ +@endrst + +@param self A pointer to a tsk_mutation_table_t object. +@param keep Array of boolean flags describing whether a particular + row should be kept or not. Must be at least ``num_rows`` long. +@param options Bitwise option flags. Currently unused; should be + set to zero to ensure compatibility with later versions of tskit. +@param id_map An array in which to store the mapping between new + and old IDs. If NULL, this will be ignored. +@return Return 0 on success or a negative value on failure. +*/ +int tsk_mutation_table_keep_rows( + tsk_mutation_table_t *self, const bool *keep, tsk_flags_t options, tsk_id_t *id_map); + /** @brief Returns true if the data in the specified table is identical to the data in this table. @@ -3040,6 +3246,36 @@ int tsk_population_table_extend(tsk_population_table_t *self, const tsk_population_table_t *other, tsk_size_t num_rows, const tsk_id_t *row_indexes, tsk_flags_t options); +/** +@brief Subset this table by keeping rows according to a boolean mask. + +@rst +Deletes rows from this table and optionally return the mapping from IDs in +the current table to the updated table. Rows are kept or deleted according to +the specified boolean array ``keep`` such that for each row ``j`` if +``keep[j]`` is false (zero) the row is deleted, and otherwise the row is +retained. Thus, ``keep`` must be an array of at least ``num_rows`` +:c:type:`bool` values. + +If the ``id_map`` argument is non-null, this array will be updated to represent +the mapping between IDs before and after row deletion. For row ``j``, +``id_map[j]`` will contain the new ID for row ``j`` if it is retained, or +:c:macro:`TSK_NULL` if the row has been removed. Thus, ``id_map`` must be an +array of at least ``num_rows`` :c:type:`tsk_id_t` values. +@endrst + +@param self A pointer to a tsk_population_table_t object. +@param keep Array of boolean flags describing whether a particular + row should be kept or not. Must be at least ``num_rows`` long. 
+@param options Bitwise option flags. Currently unused; should be + set to zero to ensure compatibility with later versions of tskit. +@param id_map An array in which to store the mapping between new + and old IDs. If NULL, this will be ignored. +@return Return 0 on success or a negative value on failure. +*/ +int tsk_population_table_keep_rows(tsk_population_table_t *self, const bool *keep, + tsk_flags_t options, tsk_id_t *id_map); + /** @brief Returns true if the data in the specified table is identical to the data in this table. @@ -3334,6 +3570,36 @@ int tsk_provenance_table_extend(tsk_provenance_table_t *self, const tsk_provenance_table_t *other, tsk_size_t num_rows, const tsk_id_t *row_indexes, tsk_flags_t options); +/** +@brief Subset this table by keeping rows according to a boolean mask. + +@rst +Deletes rows from this table and optionally return the mapping from IDs in +the current table to the updated table. Rows are kept or deleted according to +the specified boolean array ``keep`` such that for each row ``j`` if +``keep[j]`` is false (zero) the row is deleted, and otherwise the row is +retained. Thus, ``keep`` must be an array of at least ``num_rows`` +:c:type:`bool` values. + +If the ``id_map`` argument is non-null, this array will be updated to represent +the mapping between IDs before and after row deletion. For row ``j``, +``id_map[j]`` will contain the new ID for row ``j`` if it is retained, or +:c:macro:`TSK_NULL` if the row has been removed. Thus, ``id_map`` must be an +array of at least ``num_rows`` :c:type:`tsk_id_t` values. +@endrst + +@param self A pointer to a tsk_provenance_table_t object. +@param keep Array of boolean flags describing whether a particular + row should be kept or not. Must be at least ``num_rows`` long. +@param options Bitwise option flags. Currently unused; should be + set to zero to ensure compatibility with later versions of tskit. +@param id_map An array in which to store the mapping between new + and old IDs. 
If NULL, this will be ignored. +@return Return 0 on success or a negative value on failure. +*/ +int tsk_provenance_table_keep_rows(tsk_provenance_table_t *self, const bool *keep, + tsk_flags_t options, tsk_id_t *id_map); + /** @brief Returns true if the data in the specified table is identical to the data in this table. From 876f1920b6687eb7eaf82f3d0768060582b28053 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 3 Feb 2023 16:57:00 +0000 Subject: [PATCH 35/84] Low-level Cpython interface for subset --- python/_tskitmodule.c | 272 +++++++++++++++++++++++++++++++++- python/tests/test_lowlevel.py | 50 ++++++- 2 files changed, 318 insertions(+), 4 deletions(-) diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 5ed421c242..17b1832b6a 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -942,13 +942,14 @@ tsk_id_converter(PyObject *py_obj, tsk_id_t *id_out) } static int -int32_array_converter(PyObject *py_obj, PyArrayObject **array_out) +array_converter(int type, PyObject *py_obj, PyArrayObject **array_out) { int ret = 0; PyArrayObject *temp_array; temp_array = (PyArrayObject *) PyArray_FromAny( - py_obj, PyArray_DescrFromType(NPY_INT32), 1, 1, NPY_ARRAY_IN_ARRAY, NULL); + py_obj, PyArray_DescrFromType(type), 1, 1, NPY_ARRAY_IN_ARRAY, NULL); + if (temp_array == NULL) { goto out; } @@ -958,6 +959,67 @@ int32_array_converter(PyObject *py_obj, PyArrayObject **array_out) return ret; } +static int +int32_array_converter(PyObject *py_obj, PyArrayObject **array_out) +{ + return array_converter(NPY_INT32, py_obj, array_out); +} + +static int +bool_array_converter(PyObject *py_obj, PyArrayObject **array_out) +{ + /* We are assuming that npy_bool and C99 bool are interchangeable, which + * may not always be true. If this ever crops up in the real world we + * may want to promote this to a module-load time check. 
+ */ + assert(sizeof(npy_bool) == sizeof(bool)); + return array_converter(NPY_BOOL, py_obj, array_out); +} + +/* Note: it doesn't seem to be possible to cast pointers to the actual + * table functions to this type because the first argument must be a + * void *, so the simplest option is to put in a small shim that + * wraps the library function and casts to the correct table type. + */ +typedef int keep_row_func_t( + void *self, const bool *keep, tsk_flags_t options, tsk_id_t *id_map); + +static PyObject * +table_keep_rows( + PyObject *args, void *table, tsk_size_t num_rows, keep_row_func_t keep_row_func) +{ + + PyObject *ret = NULL; + PyArrayObject *keep = NULL; + PyArrayObject *id_map = NULL; + npy_intp n = (npy_intp) num_rows; + int err; + + if (!PyArg_ParseTuple(args, "O&", &bool_array_converter, &keep)) { + goto out; + } + if (PyArray_DIMS(keep)[0] != n) { + PyErr_SetString(PyExc_ValueError, "keep array must be of length Table.num_rows"); + goto out; + } + id_map = (PyArrayObject *) PyArray_SimpleNew(1, &n, NPY_INT32); + if (id_map == NULL) { + goto out; + } + err = keep_row_func(table, PyArray_DATA(keep), 0, PyArray_DATA(id_map)); + + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = (PyObject *) id_map; + id_map = NULL; +out: + Py_XDECREF(keep); + Py_XDECREF(id_map); + return ret; +} + /*=================================================================== * IndividualTable *=================================================================== @@ -1332,6 +1394,28 @@ IndividualTable_extend(IndividualTable *self, PyObject *args, PyObject *kwds) return ret; } +static int +individual_table_keep_rows_generic( + void *table, const bool *keep, tsk_flags_t options, tsk_id_t *id_map) +{ + return tsk_individual_table_keep_rows( + (tsk_individual_table_t *) table, keep, options, id_map); +} + +static PyObject * +IndividualTable_keep_rows(IndividualTable *self, PyObject *args) +{ + PyObject *ret = NULL; + + if (IndividualTable_check_state(self) != 0) { + 
goto out; + } + ret = table_keep_rows(args, (void *) self->table, self->table->num_rows, + individual_table_keep_rows_generic); +out: + return ret; +} + static PyObject * IndividualTable_get_max_rows_increment(IndividualTable *self, void *closure) { @@ -1578,6 +1662,10 @@ static PyMethodDef IndividualTable_methods[] = { .ml_meth = (PyCFunction) IndividualTable_extend, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Extend this table from another using specified row_indexes" }, + { .ml_name = "keep_rows", + .ml_meth = (PyCFunction) IndividualTable_keep_rows, + .ml_flags = METH_VARARGS, + .ml_doc = "Keep rows in this table according to boolean array" }, { NULL } /* Sentinel */ }; @@ -1911,6 +1999,27 @@ NodeTable_extend(NodeTable *self, PyObject *args, PyObject *kwds) return ret; } +static int +node_table_keep_rows_generic( + void *table, const bool *keep, tsk_flags_t options, tsk_id_t *id_map) +{ + return tsk_node_table_keep_rows((tsk_node_table_t *) table, keep, options, id_map); +} + +static PyObject * +NodeTable_keep_rows(NodeTable *self, PyObject *args) +{ + PyObject *ret = NULL; + + if (NodeTable_check_state(self) != 0) { + goto out; + } + ret = table_keep_rows( + args, (void *) self->table, self->table->num_rows, node_table_keep_rows_generic); +out: + return ret; +} + static PyObject * NodeTable_get_max_rows_increment(NodeTable *self, void *closure) { @@ -2138,6 +2247,10 @@ static PyMethodDef NodeTable_methods[] = { .ml_meth = (PyCFunction) NodeTable_extend, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Extend this table from another using specified row_indexes" }, + { .ml_name = "keep_rows", + .ml_meth = (PyCFunction) NodeTable_keep_rows, + .ml_flags = METH_VARARGS, + .ml_doc = "Keep rows in this table according to boolean array" }, { NULL } /* Sentinel */ }; @@ -2482,6 +2595,27 @@ EdgeTable_extend(EdgeTable *self, PyObject *args, PyObject *kwds) return ret; } +static int +edge_table_keep_rows_generic( + void *table, const bool *keep, tsk_flags_t 
options, tsk_id_t *id_map) +{ + return tsk_edge_table_keep_rows((tsk_edge_table_t *) table, keep, options, id_map); +} + +static PyObject * +EdgeTable_keep_rows(EdgeTable *self, PyObject *args) +{ + PyObject *ret = NULL; + + if (EdgeTable_check_state(self) != 0) { + goto out; + } + ret = table_keep_rows( + args, (void *) self->table, self->table->num_rows, edge_table_keep_rows_generic); +out: + return ret; +} + static PyObject * EdgeTable_get_max_rows_increment(EdgeTable *self, void *closure) { @@ -2707,11 +2841,14 @@ static PyMethodDef EdgeTable_methods[] = { .ml_meth = (PyCFunction) EdgeTable_extend, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Extend this table from another using specified row_indexes" }, - { .ml_name = "squash", .ml_meth = (PyCFunction) EdgeTable_squash, .ml_flags = METH_NOARGS, .ml_doc = "Squashes sets of edges with adjacent L,R and identical P,C values." }, + { .ml_name = "keep_rows", + .ml_meth = (PyCFunction) EdgeTable_keep_rows, + .ml_flags = METH_VARARGS, + .ml_doc = "Keep rows in this table according to boolean array" }, { NULL } /* Sentinel */ }; @@ -3039,6 +3176,28 @@ MigrationTable_extend(MigrationTable *self, PyObject *args, PyObject *kwds) return ret; } +static int +migration_table_keep_rows_generic( + void *table, const bool *keep, tsk_flags_t options, tsk_id_t *id_map) +{ + return tsk_migration_table_keep_rows( + (tsk_migration_table_t *) table, keep, options, id_map); +} + +static PyObject * +MigrationTable_keep_rows(MigrationTable *self, PyObject *args) +{ + PyObject *ret = NULL; + + if (MigrationTable_check_state(self) != 0) { + goto out; + } + ret = table_keep_rows(args, (void *) self->table, self->table->num_rows, + migration_table_keep_rows_generic); +out: + return ret; +} + static PyObject * MigrationTable_get_max_rows_increment(MigrationTable *self, void *closure) { @@ -3296,6 +3455,10 @@ static PyMethodDef MigrationTable_methods[] = { .ml_meth = (PyCFunction) MigrationTable_extend, .ml_flags = METH_VARARGS | 
METH_KEYWORDS, .ml_doc = "Extend this table from another using specified row_indexes" }, + { .ml_name = "keep_rows", + .ml_meth = (PyCFunction) MigrationTable_keep_rows, + .ml_flags = METH_VARARGS, + .ml_doc = "Keep rows in this table according to boolean array" }, { NULL } /* Sentinel */ }; @@ -3623,6 +3786,27 @@ SiteTable_extend(SiteTable *self, PyObject *args, PyObject *kwds) return ret; } +static int +site_table_keep_rows_generic( + void *table, const bool *keep, tsk_flags_t options, tsk_id_t *id_map) +{ + return tsk_site_table_keep_rows((tsk_site_table_t *) table, keep, options, id_map); +} + +static PyObject * +SiteTable_keep_rows(SiteTable *self, PyObject *args) +{ + PyObject *ret = NULL; + + if (SiteTable_check_state(self) != 0) { + goto out; + } + ret = table_keep_rows( + args, (void *) self->table, self->table->num_rows, site_table_keep_rows_generic); +out: + return ret; +} + static PyObject * SiteTable_get_max_rows_increment(SiteTable *self, void *closure) { @@ -3837,6 +4021,10 @@ static PyMethodDef SiteTable_methods[] = { .ml_meth = (PyCFunction) SiteTable_extend, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Extend this table from another using specified row_indexes" }, + { .ml_name = "keep_rows", + .ml_meth = (PyCFunction) SiteTable_keep_rows, + .ml_flags = METH_VARARGS, + .ml_doc = "Keep rows in this table according to boolean array" }, { NULL } /* Sentinel */ }; @@ -4173,6 +4361,28 @@ MutationTable_extend(MutationTable *self, PyObject *args, PyObject *kwds) return ret; } +static int +mutation_table_keep_rows_generic( + void *table, const bool *keep, tsk_flags_t options, tsk_id_t *id_map) +{ + return tsk_mutation_table_keep_rows( + (tsk_mutation_table_t *) table, keep, options, id_map); +} + +static PyObject * +MutationTable_keep_rows(MutationTable *self, PyObject *args) +{ + PyObject *ret = NULL; + + if (MutationTable_check_state(self) != 0) { + goto out; + } + ret = table_keep_rows(args, (void *) self->table, self->table->num_rows, + 
mutation_table_keep_rows_generic); +out: + return ret; +} + static PyObject * MutationTable_get_max_rows_increment(MutationTable *self, void *closure) { @@ -4432,6 +4642,10 @@ static PyMethodDef MutationTable_methods[] = { .ml_meth = (PyCFunction) MutationTable_extend, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Extend this table from another using specified row_indexes" }, + { .ml_name = "keep_rows", + .ml_meth = (PyCFunction) MutationTable_keep_rows, + .ml_flags = METH_VARARGS, + .ml_doc = "Keep rows in this table according to boolean array" }, { NULL } /* Sentinel */ }; @@ -4754,6 +4968,28 @@ PopulationTable_extend(PopulationTable *self, PyObject *args, PyObject *kwds) return ret; } +static int +population_table_keep_rows_generic( + void *table, const bool *keep, tsk_flags_t options, tsk_id_t *id_map) +{ + return tsk_population_table_keep_rows( + (tsk_population_table_t *) table, keep, options, id_map); +} + +static PyObject * +PopulationTable_keep_rows(PopulationTable *self, PyObject *args) +{ + PyObject *ret = NULL; + + if (PopulationTable_check_state(self) != 0) { + goto out; + } + ret = table_keep_rows(args, (void *) self->table, self->table->num_rows, + population_table_keep_rows_generic); +out: + return ret; +} + static PyObject * PopulationTable_get_max_rows_increment(PopulationTable *self, void *closure) { @@ -4918,6 +5154,10 @@ static PyMethodDef PopulationTable_methods[] = { .ml_meth = (PyCFunction) PopulationTable_extend, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Extend this table from another using specified row_indexes" }, + { .ml_name = "keep_rows", + .ml_meth = (PyCFunction) PopulationTable_keep_rows, + .ml_flags = METH_VARARGS, + .ml_doc = "Keep rows in this table according to boolean array" }, { NULL } /* Sentinel */ }; @@ -5232,6 +5472,28 @@ ProvenanceTable_extend(ProvenanceTable *self, PyObject *args, PyObject *kwds) return ret; } +static int +provenance_table_keep_rows_generic( + void *table, const bool *keep, tsk_flags_t 
options, tsk_id_t *id_map) +{ + return tsk_provenance_table_keep_rows( + (tsk_provenance_table_t *) table, keep, options, id_map); +} + +static PyObject * +ProvenanceTable_keep_rows(ProvenanceTable *self, PyObject *args) +{ + PyObject *ret = NULL; + + if (ProvenanceTable_check_state(self) != 0) { + goto out; + } + ret = table_keep_rows(args, (void *) self->table, self->table->num_rows, + provenance_table_keep_rows_generic); +out: + return ret; +} + static PyObject * ProvenanceTable_get_max_rows_increment(ProvenanceTable *self, void *closure) { @@ -5385,6 +5647,10 @@ static PyMethodDef ProvenanceTable_methods[] = { .ml_meth = (PyCFunction) ProvenanceTable_extend, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Extend this table from another using specified row_indexes" }, + { .ml_name = "keep_rows", + .ml_meth = (PyCFunction) ProvenanceTable_keep_rows, + .ml_flags = METH_VARARGS, + .ml_doc = "Keep rows in this table according to boolean array" }, { NULL } /* Sentinel */ }; diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index e599ad62c9..8bbc201b85 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # Copyright (c) 2015-2018 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -763,6 +763,54 @@ def test_table_extend_types( for i, expected_row in enumerate(expected_rows): assert table[len(table_copy) + i] == table_copy[expected_row] + @pytest.mark.parametrize("table_name", tskit.TABLE_NAMES) + def test_table_keep_rows_errors(self, table_name, ts_fixture): + table = getattr(ts_fixture.tables, table_name) + n = len(table) + ll_table = table.ll_table + with pytest.raises(ValueError, match="must be of length"): + ll_table.keep_rows(np.ones(n - 1, dtype=bool)) + with pytest.raises(ValueError, match="must be of length"): + 
ll_table.keep_rows(np.ones(n + 1, dtype=bool)) + with pytest.raises(TypeError, match="Cannot cast"): + ll_table.keep_rows(np.ones(n, dtype=int)) + + @pytest.mark.parametrize("table_name", tskit.TABLE_NAMES) + def test_table_keep_rows_all(self, table_name, ts_fixture): + table = getattr(ts_fixture.tables, table_name) + n = len(table) + ll_table = table.ll_table + a = ll_table.keep_rows(np.ones(n, dtype=bool)) + assert ll_table.num_rows == n + assert a.shape == (n,) + assert a.dtype == np.int32 + assert np.all(a == np.arange(n)) + + @pytest.mark.parametrize("table_name", tskit.TABLE_NAMES) + def test_table_keep_rows_none(self, table_name, ts_fixture): + table = getattr(ts_fixture.tables, table_name) + n = len(table) + ll_table = table.ll_table + a = ll_table.keep_rows(np.zeros(n, dtype=bool)) + assert ll_table.num_rows == 0 + assert a.shape == (n,) + assert a.dtype == np.int32 + assert np.all(a == -1) + + def test_mutation_table_keep_rows_ref_error(self): + table = _tskit.MutationTable() + table.add_row(site=0, node=0, derived_state="A", parent=2) + with pytest.raises(_tskit.LibraryError, match="TSK_ERR_MUTATION_OUT_OF_BOUNDS"): + table.keep_rows([True]) + + def test_individual_table_keep_rows_ref_error(self): + table = _tskit.IndividualTable() + table.add_row(parents=[2]) + with pytest.raises( + _tskit.LibraryError, match="TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS" + ): + table.keep_rows([True]) + @pytest.mark.parametrize( ["table_name", "column_name"], [ From 7927e2c2525fcb2fa950c18a69d4e3ac938941ec Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 10 Feb 2023 06:25:49 +0000 Subject: [PATCH 36/84] High-level interface for keep_rows --- docs/substitutions/table_keep_rows_main.rst | 14 + python/CHANGELOG.rst | 4 +- python/_tskitmodule.c | 4 +- python/tests/test_tables.py | 403 +++++++++++++++++++- python/tskit/tables.py | 78 ++++ 5 files changed, 500 insertions(+), 3 deletions(-) create mode 100644 docs/substitutions/table_keep_rows_main.rst diff --git 
a/docs/substitutions/table_keep_rows_main.rst b/docs/substitutions/table_keep_rows_main.rst new file mode 100644 index 0000000000..95652527a2 --- /dev/null +++ b/docs/substitutions/table_keep_rows_main.rst @@ -0,0 +1,14 @@ +Updates this table in-place according to the specified boolean +array, and returns the resulting mapping from old to new row IDs. +For each row ``j``, if ``keep[j]`` is True, that row will be +retained in the output; otherwise, the row will be deleted. +Rows are retained in their original ordering. + +The returned ``id_map`` is an array of the same length as +this table before the operation, such that ``id_map[j] = -1`` +(:data:`tskit.NULL`) if row ``j`` was deleted, and ``id_map[j]`` +is the new ID of that row, otherwise. + +.. todo:: + This needs some examples to link to. See + https://github.com/tskit-dev/tskit/issues/2708 diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index a3738ea06b..3a81b2aa1c 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -7,7 +7,9 @@ - Add ``__repr__`` for variants to return a string representation of the raw data without spewing megabytes of text (:user:`chriscrsmith`, :pr:`2695`, :issue:`2694`) - +- Add ``keep_rows`` method to table classes to support efficient in-place + table subsetting (:user:`jeromekelleher`, :pr:`2700`) + -------------------- [0.5.4] - 2023-01-13 -------------------- diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 17b1832b6a..aaf8e25406 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -993,12 +993,14 @@ table_keep_rows( PyArrayObject *keep = NULL; PyArrayObject *id_map = NULL; npy_intp n = (npy_intp) num_rows; + npy_intp array_len; int err; if (!PyArg_ParseTuple(args, "O&", &bool_array_converter, &keep)) { goto out; } - if (PyArray_DIMS(keep)[0] != n) { + array_len = PyArray_DIMS(keep)[0]; + if (array_len != n) { PyErr_SetString(PyExc_ValueError, "keep array must be of length Table.num_rows"); goto out; } diff --git 
a/python/tests/test_tables.py b/python/tests/test_tables.py index 919f2309ef..b6d207e2e0 100644 --- a/python/tests/test_tables.py +++ b/python/tests/test_tables.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # Copyright (c) 2017 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -625,6 +625,18 @@ def test_append_columns_max_rows(self): else: assert table.max_rows == max(max_rows + 1, table.num_rows) + def test_keep_rows_data(self): + input_data = self.make_input_data(100) + t1 = self.table_class() + t1.append_columns(**input_data) + t2 = t1.copy() + keep = np.ones(len(t1), dtype=bool) + # Only keep even + keep[::2] = 0 + t1.keep_rows(keep) + keep_rows_definition(t2, keep) + assert t1.equals(t2) + def test_str(self): for num_rows in [0, 10]: input_data = self.make_input_data(num_rows) @@ -1729,6 +1741,21 @@ def test_various_not_equals(self): a = tskit.MutationTableRow(**args) assert a == b + def test_keep_rows_data(self): + input_data = self.make_input_data(100) + t1 = self.table_class() + # Set the parent column to -1s for this simple test as + # we need to reason about reference integrity + t1.append_columns(**input_data) + t1.parents = np.full_like(t1.parents, -1) + t2 = t1.copy() + keep = np.ones(len(t1), dtype=bool) + # Only keep even + keep[::2] = 0 + t1.keep_rows(keep) + keep_rows_definition(t2, keep) + assert t1.equals(t2) + class TestNodeTable(*common_tests): @@ -1992,6 +2019,21 @@ def test_packset_derived_state(self): assert np.array_equal(table.derived_state, derived_state) assert np.array_equal(table.derived_state_offset, derived_state_offset) + def test_keep_rows_data(self): + input_data = self.make_input_data(100) + t1 = self.table_class() + # Set the parent column to -1s for this simple test as + # we need to reason about reference integrity + t1.append_columns(**input_data) + t1.parent = np.full_like(t1.parent, -1) + t2 = 
t1.copy() + keep = np.ones(len(t1), dtype=bool) + # Only keep even + keep[::2] = 0 + t1.keep_rows(keep) + keep_rows_definition(t2, keep) + assert t1.equals(t2) + class TestMigrationTable(*common_tests): columns = [ @@ -5011,3 +5053,362 @@ def test_setitem_metadata(self, ts_fixture, table_name): assert table[0].metadata != table[1].metadata table[0] = table[1] assert table[0] == table[1] + + +def keep_rows_definition(table, keep): + id_map = np.full(len(table), -1, np.int32) + copy = table.copy() + table.clear() + for j, row in enumerate(copy): + if keep[j]: + id_map[j] = len(table) + table.append(row) + return id_map + + +class KeepRowsBaseTest: + # Simple tests assuming that rows aren't self-referential + + def test_keep_all(self, ts_fixture): + table = self.get_table(ts_fixture) + before = table.copy() + table.keep_rows(np.ones(len(table), dtype=bool)) + assert table.equals(before) + + def test_keep_none(self, ts_fixture): + table = self.get_table(ts_fixture) + table.keep_rows(np.zeros(len(table), dtype=bool)) + assert len(table) == 0 + + def check_keep_rows(self, table, keep): + copy = table.copy() + id_map1 = keep_rows_definition(copy, keep) + id_map2 = table.keep_rows(keep) + table.assert_equals(copy) + np.testing.assert_array_equal(id_map1, id_map2) + + def test_keep_even(self, ts_fixture): + table = self.get_table(ts_fixture) + keep = np.ones(len(table), dtype=bool) + keep[1::2] = 0 + self.check_keep_rows(table, keep) + + def test_keep_odd(self, ts_fixture): + table = self.get_table(ts_fixture) + keep = np.ones(len(table), dtype=bool) + keep[::2] = 0 + self.check_keep_rows(table, keep) + + def test_keep_first(self, ts_fixture): + table = self.get_table(ts_fixture) + keep = np.zeros(len(table), dtype=bool) + keep[0] = 1 + self.check_keep_rows(table, keep) + assert len(table) == 1 + + def test_keep_last(self, ts_fixture): + table = self.get_table(ts_fixture) + keep = np.zeros(len(table), dtype=bool) + keep[-1] = 1 + self.check_keep_rows(table, keep) + assert 
len(table) == 1 + + @pytest.mark.parametrize("dtype", [np.int32, int, np.float32]) + def test_bad_array_dtype(self, ts_fixture, dtype): + table = self.get_table(ts_fixture) + keep = np.zeros(len(table), dtype=dtype) + with pytest.raises(TypeError, match="Cannot cast array"): + table.keep_rows(keep) + + @pytest.mark.parametrize("truthy", [False, 0, "", None]) + def test_python_falsey_input(self, ts_fixture, truthy): + table = self.get_table(ts_fixture) + keep = [truthy] * len(table) + self.check_keep_rows(table, keep) + assert len(table) == 0 + + @pytest.mark.parametrize("truthy", [True, 1, "string", 1e-6]) + def test_python_truey_input(self, ts_fixture, truthy): + table = self.get_table(ts_fixture) + n = len(table) + keep = [truthy] * len(table) + self.check_keep_rows(table, keep) + assert len(table) == n + + @pytest.mark.parametrize("offset", [-1, 1, 100]) + def test_bad_length(self, ts_fixture, offset): + table = self.get_table(ts_fixture) + keep = [True] * (len(table) + offset) + match_str = f"need:{len(table)}, got:{len(table) + offset}" + with pytest.raises(ValueError, match=match_str): + table.keep_rows(keep) + + @pytest.mark.parametrize("bad_type", [False, 0, None]) + def test_non_list_input(self, ts_fixture, bad_type): + table = self.get_table(ts_fixture) + with pytest.raises(TypeError, match="has no len"): + table.keep_rows(bad_type) + + +class TestNodeTableKeepRows(KeepRowsBaseTest): + def get_table(self, ts): + return ts.dump_tables().nodes + + +class TestEdgeTableKeepRows(KeepRowsBaseTest): + def get_table(self, ts): + return ts.dump_tables().edges + + +class TestSiteTableKeepRows(KeepRowsBaseTest): + def get_table(self, ts): + return ts.dump_tables().sites + + +class TestMigrationTableKeepRows(KeepRowsBaseTest): + def get_table(self, ts): + return ts.dump_tables().migrations + + +class TestPopulationTableKeepRows(KeepRowsBaseTest): + def get_table(self, ts): + return ts.dump_tables().populations + + +class TestProvenanceTableKeepRows(KeepRowsBaseTest): 
+ def get_table(self, ts): + return ts.dump_tables().provenances + + +# Null out the self-referential columns (this is why the tests are structed via +# classes rather than pytest parametrize. + + +class TestIndividualTableKeepRows(KeepRowsBaseTest): + def get_table(self, ts): + table = ts.dump_tables().individuals + table.parents = np.zeros_like(table.parents) - 1 + return table + + def check_keep_rows(self, table, keep): + copy = table.copy() + id_map1 = keep_rows_definition(copy, keep) + for j, row in enumerate(copy): + parents = [p if p == tskit.NULL else id_map1[p] for p in row.parents] + copy[j] = row.replace(parents=parents) + id_map2 = table.keep_rows(keep) + table.assert_equals(copy) + np.testing.assert_array_equal(id_map1, id_map2) + + def test_delete_unreferenced(self, ts_fixture): + table = ts_fixture.dump_tables().individuals + ref_count = np.zeros(len(table)) + for row in table: + for parent in row.parents: + ref_count[parent] += 1 + self.check_keep_rows(table, ref_count > 0) + + +class TestMutationTableKeepRows(KeepRowsBaseTest): + def get_table(self, ts): + table = ts.dump_tables().mutations + table.parent = np.zeros_like(table.parent) - 1 + return table + + def check_keep_rows(self, table, keep): + copy = table.copy() + id_map1 = keep_rows_definition(copy, keep) + for j, row in enumerate(copy): + if row.parent != tskit.NULL: + copy[j] = row.replace(parent=id_map1[row.parent]) + id_map2 = table.keep_rows(keep) + table.assert_equals(copy) + np.testing.assert_array_equal(id_map1, id_map2) + + def test_delete_unreferenced(self, ts_fixture): + table = ts_fixture.dump_tables().mutations + parent = table.parent.copy() + parent[parent == tskit.NULL] = len(table) + references = np.bincount(parent) + self.check_keep_rows(table, references[:-1] > 0) + + def test_error_on_bad_ids(self, ts_fixture): + table = ts_fixture.dump_tables().mutations + table.add_row(site=0, node=0, derived_state="A", parent=10000) + before = table.copy() + with 
pytest.raises(tskit.LibraryError, match="TSK_ERR_MUTATION_OUT_OF_BOUNDS"): + table.keep_rows(np.ones(len(table), dtype=bool)) + table.assert_equals(before) + + +class TestKeepRowsExamples: + """ + Some examples of how to use the keep_rows method in an idiomatic + and efficient way. + + TODO these should be converted into documentation examples when we + write an "examples" section for table editing. + """ + + def test_detach_subtree(self): + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(3).tree_sequence + tables = ts.dump_tables() + tables.edges.keep_rows(tables.edges.child != 3) + + # 2.00┊ 4 ┊ + # ┊ ┃ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + ts = tables.tree_sequence() + assert ts.num_trees == 1 + assert ts.first().parent_dict == {0: 4, 1: 3, 2: 3} + + def test_delete_older_edges(self): + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(3).tree_sequence + tables = ts.dump_tables() + tables.edges.keep_rows(tables.nodes.time[tables.edges.parent] <= 1) + + # 2.00┊ ┊ + # ┊ ┊ + # 1.00┊ 3 ┊ + # ┊ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + ts = tables.tree_sequence() + assert ts.num_trees == 1 + assert ts.first().parent_dict == {1: 3, 2: 3} + + def test_delete_unreferenced_nodes(self): + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(3).tree_sequence + tables = ts.dump_tables() + edges = tables.edges + nodes = tables.nodes + edges.keep_rows(nodes.time[edges.parent] <= 1) + # 2.00┊ ┊ + # ┊ ┊ + # 1.00┊ 3 ┊ + # ┊ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + ref_count = np.bincount(edges.child, minlength=len(nodes)) + ref_count += np.bincount(edges.parent, minlength=len(nodes)) + assert list(ref_count) == [0, 1, 1, 2, 0] + id_map = nodes.keep_rows(ref_count > 0) + assert list(id_map) == [-1, 0, 1, 2, -1] + assert len(nodes) == 3 + # Remap the edges IDs + edges.child = 
id_map[edges.child] + edges.parent = id_map[edges.parent] + ts = tables.tree_sequence() + assert ts.num_trees == 1 + assert ts.first().parent_dict == {0: 2, 1: 2} + + def test_mutation_ids_auto_remapped(self): + mutations = tskit.MutationTable() + # Add 5 initial rows with no parents + for j in range(5): + mutations.add_row(site=j, node=j, derived_state=f"{j}") + # Now 5 more in a chain + last = -1 + for j in range(5): + last = mutations.add_row( + site=10 + j, node=10 + j, parent=last, derived_state=f"{j}" + ) + + # ╔══╤════╤════╤════╤═════════════╤══════╤════════╗ + # ║id│site│node│time│derived_state│parent│metadata║ + # ╠══╪════╪════╪════╪═════════════╪══════╪════════╣ + # ║0 │ 0│ 0│ nan│ 0│ -1│ ║ + # ║1 │ 1│ 1│ nan│ 1│ -1│ ║ + # ║2 │ 2│ 2│ nan│ 2│ -1│ ║ + # ║3 │ 3│ 3│ nan│ 3│ -1│ ║ + # ║4 │ 4│ 4│ nan│ 4│ -1│ ║ + # ║5 │ 10│ 10│ nan│ 0│ -1│ ║ + # ║6 │ 11│ 11│ nan│ 1│ 5│ ║ + # ║7 │ 12│ 12│ nan│ 2│ 6│ ║ + # ║8 │ 13│ 13│ nan│ 3│ 7│ ║ + # ║9 │ 14│ 14│ nan│ 4│ 8│ ║ + # ╚══╧════╧════╧════╧═════════════╧══════╧════════╝ + + keep = np.ones(len(mutations), dtype=bool) + keep[:5] = False + mutations.keep_rows(keep) + + # ╔══╤════╤════╤════╤═════════════╤══════╤════════╗ + # ║id│site│node│time│derived_state│parent│metadata║ + # ╠══╪════╪════╪════╪═════════════╪══════╪════════╣ + # ║0 │ 10│ 10│ nan│ 0│ -1│ ║ + # ║1 │ 11│ 11│ nan│ 1│ 0│ ║ + # ║2 │ 12│ 12│ nan│ 2│ 1│ ║ + # ║3 │ 13│ 13│ nan│ 3│ 2│ ║ + # ║4 │ 14│ 14│ nan│ 4│ 3│ ║ + # ╚══╧════╧════╧════╧═════════════╧══════╧════════╝ + assert list(mutations.site) == [10, 11, 12, 13, 14] + assert list(mutations.node) == [10, 11, 12, 13, 14] + assert list(mutations.parent) == [-1, 0, 1, 2, 3] + + def test_individual_ids_auto_remapped(self): + individuals = tskit.IndividualTable() + # Add some rows with missing parents in different forms + individuals.add_row() + individuals.add_row(parents=[-1]) + individuals.add_row(parents=[-1, -1]) + # Now 5 more in a chain + last = -1 + for _ in range(5): + last = 
individuals.add_row(parents=[last]) + last = individuals.add_row(parents=[last, last]) + + # ╔══╤═════╤════════╤═══════╤════════╗ + # ║id│flags│location│parents│metadata║ + # ╠══╪═════╪════════╪═══════╪════════╣ + # ║0 │ 0│ │ │ ║ + # ║1 │ 0│ │ -1│ ║ + # ║2 │ 0│ │ -1, -1│ ║ + # ║3 │ 0│ │ -1│ ║ + # ║4 │ 0│ │ 3│ ║ + # ║5 │ 0│ │ 4│ ║ + # ║6 │ 0│ │ 5│ ║ + # ║7 │ 0│ │ 6│ ║ + # ║8 │ 0│ │ 7, 7│ ║ + # ╚══╧═════╧════════╧═══════╧════════╝ + + keep = np.ones(len(individuals), dtype=bool) + # Only delete one row + keep[1] = False + individuals.keep_rows(keep) + + # ╔══╤═════╤════════╤═══════╤════════╗ + # ║id│flags│location│parents│metadata║ + # ╠══╪═════╪════════╪═══════╪════════╣ + # ║0 │ 0│ │ │ ║ + # ║1 │ 0│ │ -1, -1│ ║ + # ║2 │ 0│ │ -1│ ║ + # ║3 │ 0│ │ 2│ ║ + # ║4 │ 0│ │ 3│ ║ + # ║5 │ 0│ │ 4│ ║ + # ║6 │ 0│ │ 5│ ║ + # ║7 │ 0│ │ 6, 6│ ║ + # ╚══╧═════╧════════╧═══════╧════════╝ + parents = [list(ind.parents) for ind in individuals] + assert parents == [[], [-1, -1], [-1], [2], [3], [4], [5], [6, 6]] diff --git a/python/tskit/tables.py b/python/tskit/tables.py index 6e8be4eef8..cb2ff7d01f 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -612,6 +612,30 @@ def truncate(self, num_rows): """ return self.ll_table.truncate(num_rows) + def keep_rows(self, keep): + """ + .. include:: substitutions/table_keep_rows_main.rst + + :param array-like keep: The rows to keep as a boolean array. Must + be the same length as the table, and convertible to a numpy + array of dtype bool. + :return: The mapping between old and new row IDs as a numpy + array (dtype int32). + :rtype: numpy.ndarray (dtype=np.int32) + """ + # We do this check here rather than in the C code because calling + # len() on the input will cause a more readable exception to be + # raised than the inscrutable errors we get from numpy when + # converting arguments of the wrong type. + if len(keep) != len(self): + msg = ( + "Argument for keep_rows must be a boolean array of " + "the same length as the table. 
" + f"(need:{len(self)}, got:{len(keep)})" + ) + raise ValueError(msg) + return self.ll_table.keep_rows(keep) + # Pickle support def __getstate__(self): return self.asdict() @@ -1023,6 +1047,33 @@ def packset_parents(self, parents): d["parents_offset"] = offset self.set_columns(**d) + def keep_rows(self, keep): + """ + .. include:: substitutions/table_keep_rows_main.rst + + The values in the ``parents`` column are updated according to this + map, so that reference integrity within the table is maintained. + As a consequence of this, the values in the ``parents`` column + for kept rows are bounds-checked and an error raised if they + are not valid. Rows that are deleted are not checked for + parent ID integrity. + + If an attempt is made to delete rows that are referred to by + the ``parents`` column of rows that are retained, an error + is raised. + + These error conditions are checked before any alterations to + the table are made. + + :param array-like keep: The rows to keep as a boolean array. Must + be the same length as the table, and convertible to a numpy + array of dtype bool. + :return: The mapping between old and new row IDs as a numpy + array (dtype int32). + :rtype: numpy.ndarray (dtype=np.int32) + """ + return super().keep_rows(keep) + class NodeTable(MetadataTable): """ @@ -2111,6 +2162,33 @@ def packset_derived_state(self, derived_states): d["derived_state_offset"] = offset self.set_columns(**d) + def keep_rows(self, keep): + """ + .. include:: substitutions/table_keep_rows_main.rst + + The values in the ``parent`` column are updated according to this + map, so that reference integrity within the table is maintained. + As a consequence of this, the values in the ``parent`` column + for kept rows are bounds-checked and an error raised if they + are not valid. Rows that are deleted are not checked for + parent ID integrity. 
+ + If an attempt is made to delete rows that are referred to by + the ``parent`` column of rows that are retained, an error + is raised. + + These error conditions are checked before any alterations to + the table are made. + + :param array-like keep: The rows to keep as a boolean array. Must + be the same length as the table, and convertible to a numpy + array of dtype bool. + :return: The mapping between old and new row IDs as a numpy + array (dtype int32). + :rtype: numpy.ndarray (dtype=np.int32) + """ + return super().keep_rows(keep) + class PopulationTable(MetadataTable): """ From 6c8c77940cf51e673d3c6eb3c7dbf359d259dca9 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Wed, 22 Feb 2023 20:02:41 +0000 Subject: [PATCH 37/84] Add tsk_bool_t and change keep_rows signature --- c/tests/test_tables.c | 22 ++++++------ c/tskit/core.h | 9 +++++ c/tskit/tables.c | 35 +++++++++---------- c/tskit/tables.h | 78 ++++++++++++++++++++++++++++++++++++------- docs/_config.yml | 1 + docs/c-api.rst | 1 + python/_tskitmodule.c | 23 +++++-------- 7 files changed, 115 insertions(+), 54 deletions(-) diff --git a/c/tests/test_tables.c b/c/tests/test_tables.c index 65ff5916ad..6de6675ff6 100644 --- a/c/tests/test_tables.c +++ b/c/tests/test_tables.c @@ -1499,7 +1499,7 @@ test_node_table_keep_rows(void) tsk_size_t j; tsk_node_table_t source, t1, t2; tsk_node_t row; - bool keep[3] = { 1, 1, 1 }; + tsk_bool_t keep[3] = { 1, 1, 1 }; tsk_id_t id_map[3]; const char *metadata = "ABC"; tsk_id_t indexes[] = { 0, 1, 2 }; @@ -2142,7 +2142,7 @@ test_edge_table_keep_rows(void) tsk_size_t j; tsk_edge_table_t source, t1, t2; tsk_edge_t row; - bool keep[3] = { 1, 1, 1 }; + tsk_bool_t keep[3] = { 1, 1, 1 }; tsk_id_t id_map[3]; const char *metadata = "ABC"; tsk_id_t indexes[] = { 0, 1, 2 }; @@ -2242,7 +2242,7 @@ test_edge_table_keep_rows_no_metadata(void) tsk_size_t j; tsk_edge_table_t source, t1, t2; tsk_edge_t row; - bool keep[3] = { 1, 1, 1 }; + tsk_bool_t keep[3] = { 1, 1, 1 }; tsk_id_t 
id_map[3]; tsk_id_t indexes[] = { 0, 1, 2 }; @@ -3265,7 +3265,7 @@ test_site_table_keep_rows(void) tsk_site_t row; const char *ancestral_state = "XYZ"; const char *metadata = "ABC"; - bool keep[3] = { 1, 1, 1 }; + tsk_bool_t keep[3] = { 1, 1, 1 }; tsk_id_t id_map[3]; tsk_id_t indexes[] = { 0, 1, 2 }; @@ -4051,7 +4051,7 @@ test_mutation_table_keep_rows(void) tsk_mutation_t row; const char *derived_state = "XYZ"; const char *metadata = "ABC"; - bool keep[3] = { 1, 1, 1 }; + tsk_bool_t keep[3] = { 1, 1, 1 }; tsk_id_t id_map[3]; tsk_id_t indexes[] = { 0, 1, 2 }; @@ -4154,7 +4154,7 @@ test_mutation_table_keep_rows_parent_references(void) int ret; tsk_id_t ret_id; tsk_mutation_table_t source, t; - bool keep[4] = { 1, 1, 1, 1 }; + tsk_bool_t keep[4] = { 1, 1, 1, 1 }; tsk_id_t id_map[4]; ret = tsk_mutation_table_init(&source, 0); @@ -4844,7 +4844,7 @@ test_migration_table_keep_rows(void) tsk_migration_table_t source, t1, t2; tsk_migration_t row; const char *metadata = "ABC"; - bool keep[3] = { 1, 1, 1 }; + tsk_bool_t keep[3] = { 1, 1, 1 }; tsk_id_t id_map[3]; tsk_id_t indexes[] = { 0, 1, 2 }; @@ -5671,7 +5671,7 @@ test_individual_table_keep_rows(void) double location[] = { 0, 1, 2 }; tsk_id_t parents[] = { -1, 1, -1 }; const char *metadata = "ABC"; - bool keep[3] = { 1, 1, 1 }; + tsk_bool_t keep[3] = { 1, 1, 1 }; tsk_id_t indexes[] = { 0, 1, 2 }; tsk_id_t id_map[3]; tsk_individual_table_t source, t1, t2; @@ -5776,7 +5776,7 @@ test_individual_table_keep_rows_parent_references(void) int ret; tsk_id_t ret_id; tsk_individual_table_t source, t; - bool keep[] = { 1, 1, 1, 1 }; + tsk_bool_t keep[] = { 1, 1, 1, 1 }; tsk_id_t parents[] = { -1, 1, 2 }; tsk_id_t id_map[4]; @@ -6243,7 +6243,7 @@ test_population_table_keep_rows(void) tsk_population_table_t source, t1, t2; tsk_population_t row; const char *metadata = "ABC"; - bool keep[3] = { 1, 1, 1 }; + tsk_bool_t keep[3] = { 1, 1, 1 }; tsk_id_t id_map[3]; tsk_id_t indexes[] = { 0, 1, 2 }; @@ -6779,7 +6779,7 @@ 
test_provenance_table_keep_rows(void) tsk_provenance_t row; const char *timestamp = "XYZ"; const char *record = "ABC"; - bool keep[3] = { 1, 1, 1 }; + tsk_bool_t keep[3] = { 1, 1, 1 }; tsk_id_t indexes[] = { 0, 1, 2 }; tsk_id_t id_map[3]; diff --git a/c/tskit/core.h b/c/tskit/core.h index 7810a8d048..20ca5881da 100644 --- a/c/tskit/core.h +++ b/c/tskit/core.h @@ -123,6 +123,15 @@ specify options to API functions. typedef uint32_t tsk_flags_t; #define TSK_FLAGS_STORAGE_TYPE KAS_UINT32 +/** +@brief Boolean type. + +@rst +Fixed-size (1 byte) boolean values. +@endrst +*/ +typedef uint8_t tsk_bool_t; + // clang-format off /** @defgroup API_VERSION_GROUP API version macros. diff --git a/c/tskit/tables.c b/c/tskit/tables.c index c0ce82bd39..8eea85f5ad 100644 --- a/c/tskit/tables.c +++ b/c/tskit/tables.c @@ -735,7 +735,7 @@ write_metadata_schema_header( /* Utilities for in-place subsetting columns */ static tsk_size_t -count_true(tsk_size_t num_rows, const bool *restrict keep) +count_true(tsk_size_t num_rows, const tsk_bool_t *restrict keep) { tsk_size_t j; tsk_size_t count = 0; @@ -750,7 +750,7 @@ count_true(tsk_size_t num_rows, const bool *restrict keep) static void keep_mask_to_id_map( - tsk_size_t num_rows, const bool *restrict keep, tsk_id_t *restrict id_map) + tsk_size_t num_rows, const tsk_bool_t *restrict keep, tsk_id_t *restrict id_map) { tsk_size_t j; tsk_id_t next_id = 0; @@ -766,7 +766,7 @@ keep_mask_to_id_map( static tsk_size_t subset_remap_id_column(tsk_id_t *restrict column, tsk_size_t num_rows, - const bool *restrict keep, const tsk_id_t *restrict id_map) + const tsk_bool_t *restrict keep, const tsk_id_t *restrict id_map) { tsk_size_t j, k; tsk_id_t value; @@ -792,7 +792,7 @@ subset_remap_id_column(tsk_id_t *restrict column, tsk_size_t num_rows, static tsk_size_t subset_id_column( - tsk_id_t *restrict column, tsk_size_t num_rows, const bool *restrict keep) + tsk_id_t *restrict column, tsk_size_t num_rows, const tsk_bool_t *restrict keep) { tsk_size_t j, k; 
@@ -808,7 +808,7 @@ subset_id_column( static tsk_size_t subset_flags_column( - tsk_flags_t *restrict column, tsk_size_t num_rows, const bool *restrict keep) + tsk_flags_t *restrict column, tsk_size_t num_rows, const tsk_bool_t *restrict keep) { tsk_size_t j, k; @@ -824,7 +824,7 @@ subset_flags_column( static tsk_size_t subset_double_column( - double *restrict column, tsk_size_t num_rows, const bool *restrict keep) + double *restrict column, tsk_size_t num_rows, const tsk_bool_t *restrict keep) { tsk_size_t j, k; @@ -840,7 +840,7 @@ subset_double_column( static tsk_size_t subset_ragged_char_column(char *restrict data, tsk_size_t *restrict offset_col, - tsk_size_t num_rows, const bool *restrict keep) + tsk_size_t num_rows, const tsk_bool_t *restrict keep) { tsk_size_t j, k, i, offset; @@ -864,7 +864,7 @@ subset_ragged_char_column(char *restrict data, tsk_size_t *restrict offset_col, static tsk_size_t subset_ragged_double_column(double *restrict data, tsk_size_t *restrict offset_col, - tsk_size_t num_rows, const bool *restrict keep) + tsk_size_t num_rows, const tsk_bool_t *restrict keep) { tsk_size_t j, k, i, offset; @@ -888,7 +888,8 @@ subset_ragged_double_column(double *restrict data, tsk_size_t *restrict offset_c static tsk_size_t subset_remap_ragged_id_column(tsk_id_t *restrict data, tsk_size_t *restrict offset_col, - tsk_size_t num_rows, const bool *restrict keep, const tsk_id_t *restrict id_map) + tsk_size_t num_rows, const tsk_bool_t *restrict keep, + const tsk_id_t *restrict id_map) { tsk_size_t j, k, i, offset; tsk_id_t di; @@ -1804,7 +1805,7 @@ tsk_individual_table_equals(const tsk_individual_table_t *self, } int -tsk_individual_table_keep_rows(tsk_individual_table_t *self, const bool *keep, +tsk_individual_table_keep_rows(tsk_individual_table_t *self, const tsk_bool_t *keep, tsk_flags_t TSK_UNUSED(options), tsk_id_t *ret_id_map) { int ret = 0; @@ -2518,7 +2519,7 @@ tsk_node_table_get_row(const tsk_node_table_t *self, tsk_id_t index, tsk_node_t } int 
-tsk_node_table_keep_rows(tsk_node_table_t *self, const bool *keep, +tsk_node_table_keep_rows(tsk_node_table_t *self, const tsk_bool_t *keep, tsk_flags_t TSK_UNUSED(options), tsk_id_t *id_map) { int ret = 0; @@ -3210,7 +3211,7 @@ tsk_edge_table_equals( } int -tsk_edge_table_keep_rows(tsk_edge_table_t *self, const bool *keep, +tsk_edge_table_keep_rows(tsk_edge_table_t *self, const tsk_bool_t *keep, tsk_flags_t TSK_UNUSED(options), tsk_id_t *id_map) { int ret = 0; @@ -3968,7 +3969,7 @@ tsk_site_table_dump_text(const tsk_site_table_t *self, FILE *out) } int -tsk_site_table_keep_rows(tsk_site_table_t *self, const bool *keep, +tsk_site_table_keep_rows(tsk_site_table_t *self, const tsk_bool_t *keep, tsk_flags_t TSK_UNUSED(options), tsk_id_t *id_map) { int ret = 0; @@ -4733,7 +4734,7 @@ tsk_mutation_table_dump_text(const tsk_mutation_table_t *self, FILE *out) } int -tsk_mutation_table_keep_rows(tsk_mutation_table_t *self, const bool *keep, +tsk_mutation_table_keep_rows(tsk_mutation_table_t *self, const tsk_bool_t *keep, tsk_flags_t TSK_UNUSED(options), tsk_id_t *ret_id_map) { int ret = 0; @@ -5437,7 +5438,7 @@ tsk_migration_table_equals(const tsk_migration_table_t *self, } int -tsk_migration_table_keep_rows(tsk_migration_table_t *self, const bool *keep, +tsk_migration_table_keep_rows(tsk_migration_table_t *self, const tsk_bool_t *keep, tsk_flags_t TSK_UNUSED(options), tsk_id_t *id_map) { int ret = 0; @@ -6031,7 +6032,7 @@ tsk_population_table_equals(const tsk_population_table_t *self, } int -tsk_population_table_keep_rows(tsk_population_table_t *self, const bool *keep, +tsk_population_table_keep_rows(tsk_population_table_t *self, const tsk_bool_t *keep, tsk_flags_t TSK_UNUSED(options), tsk_id_t *id_map) { int ret = 0; @@ -6661,7 +6662,7 @@ tsk_provenance_table_equals(const tsk_provenance_table_t *self, } int -tsk_provenance_table_keep_rows(tsk_provenance_table_t *self, const bool *keep, +tsk_provenance_table_keep_rows(tsk_provenance_table_t *self, const tsk_bool_t *keep, 
tsk_flags_t TSK_UNUSED(options), tsk_id_t *id_map) { int ret = 0; diff --git a/c/tskit/tables.h b/c/tskit/tables.h index 13934a28db..38f3096c9d 100644 --- a/c/tskit/tables.h +++ b/c/tskit/tables.h @@ -1106,6 +1106,12 @@ column of rows that are retained, an error is raised. These error conditions are checked before any alterations to the table are made. +.. warning:: + C++ users need to be careful to specify the correct type when + passing in values for the ``keep`` array, + using ``std::vector`` and not ``std::vector``, + as the latter may not be correct size. + @endrst @param self A pointer to a tsk_individual_table_t object. @@ -1117,7 +1123,7 @@ made. and old IDs. If NULL, this will be ignored. @return Return 0 on success or a negative value on failure. */ -int tsk_individual_table_keep_rows(tsk_individual_table_t *self, const bool *keep, +int tsk_individual_table_keep_rows(tsk_individual_table_t *self, const tsk_bool_t *keep, tsk_flags_t options, tsk_id_t *id_map); /** @@ -1484,6 +1490,13 @@ the mapping between IDs before and after row deletion. For row ``j``, ``id_map[j]`` will contain the new ID for row ``j`` if it is retained, or :c:macro:`TSK_NULL` if the row has been removed. Thus, ``id_map`` must be an array of at least ``num_rows`` :c:type:`tsk_id_t` values. + +.. warning:: + C++ users need to be careful to specify the correct type when + passing in values for the ``keep`` array, + using ``std::vector`` and not ``std::vector``, + as the latter may not be correct size. + @endrst @param self A pointer to a tsk_node_table_t object. @@ -1495,8 +1508,8 @@ array of at least ``num_rows`` :c:type:`tsk_id_t` values. and old IDs. If NULL, this will be ignored. @return Return 0 on success or a negative value on failure. 
*/ -int tsk_node_table_keep_rows( - tsk_node_table_t *self, const bool *keep, tsk_flags_t options, tsk_id_t *id_map); +int tsk_node_table_keep_rows(tsk_node_table_t *self, const tsk_bool_t *keep, + tsk_flags_t options, tsk_id_t *id_map); /** @brief Returns true if the data in the specified table is identical to the data @@ -1824,6 +1837,13 @@ the mapping between IDs before and after row deletion. For row ``j``, ``id_map[j]`` will contain the new ID for row ``j`` if it is retained, or :c:macro:`TSK_NULL` if the row has been removed. Thus, ``id_map`` must be an array of at least ``num_rows`` :c:type:`tsk_id_t` values. + +.. warning:: + C++ users need to be careful to specify the correct type when + passing in values for the ``keep`` array, + using ``std::vector`` and not ``std::vector``, + as the latter may not be correct size. + @endrst @param self A pointer to a tsk_edge_table_t object. @@ -1835,8 +1855,8 @@ array of at least ``num_rows`` :c:type:`tsk_id_t` values. and old IDs. If NULL, this will be ignored. @return Return 0 on success or a negative value on failure. */ -int tsk_edge_table_keep_rows( - tsk_edge_table_t *self, const bool *keep, tsk_flags_t options, tsk_id_t *id_map); +int tsk_edge_table_keep_rows(tsk_edge_table_t *self, const tsk_bool_t *keep, + tsk_flags_t options, tsk_id_t *id_map); /** @brief Returns true if the data in the specified table is identical to the data @@ -2188,6 +2208,13 @@ the mapping between IDs before and after row deletion. For row ``j``, ``id_map[j]`` will contain the new ID for row ``j`` if it is retained, or :c:macro:`TSK_NULL` if the row has been removed. Thus, ``id_map`` must be an array of at least ``num_rows`` :c:type:`tsk_id_t` values. + +.. warning:: + C++ users need to be careful to specify the correct type when + passing in values for the ``keep`` array, + using ``std::vector`` and not ``std::vector``, + as the latter may not be correct size. + @endrst @param self A pointer to a tsk_migration_table_t object. 
@@ -2199,7 +2226,7 @@ array of at least ``num_rows`` :c:type:`tsk_id_t` values. and old IDs. If NULL, this will be ignored. @return Return 0 on success or a negative value on failure. */ -int tsk_migration_table_keep_rows(tsk_migration_table_t *self, const bool *keep, +int tsk_migration_table_keep_rows(tsk_migration_table_t *self, const tsk_bool_t *keep, tsk_flags_t options, tsk_id_t *id_map); /** @@ -2526,6 +2553,13 @@ the mapping between IDs before and after row deletion. For row ``j``, ``id_map[j]`` will contain the new ID for row ``j`` if it is retained, or :c:macro:`TSK_NULL` if the row has been removed. Thus, ``id_map`` must be an array of at least ``num_rows`` :c:type:`tsk_id_t` values. + +.. warning:: + C++ users need to be careful to specify the correct type when + passing in values for the ``keep`` array, + using ``std::vector`` and not ``std::vector``, + as the latter may not be correct size. + @endrst @param self A pointer to a tsk_site_table_t object. @@ -2537,8 +2571,8 @@ array of at least ``num_rows`` :c:type:`tsk_id_t` values. and old IDs. If NULL, this will be ignored. @return Return 0 on success or a negative value on failure. */ -int tsk_site_table_keep_rows( - tsk_site_table_t *self, const bool *keep, tsk_flags_t options, tsk_id_t *id_map); +int tsk_site_table_keep_rows(tsk_site_table_t *self, const tsk_bool_t *keep, + tsk_flags_t options, tsk_id_t *id_map); /** @brief Returns true if the data in the specified table is identical to the data @@ -2905,6 +2939,12 @@ column of rows that are retained, an error is raised. These error conditions are checked before any alterations to the table are made. +.. warning:: + C++ users need to be careful to specify the correct type when + passing in values for the ``keep`` array, + using ``std::vector`` and not ``std::vector``, + as the latter may not be correct size. + @endrst @param self A pointer to a tsk_mutation_table_t object. @@ -2916,8 +2956,8 @@ made. and old IDs. If NULL, this will be ignored. 
@return Return 0 on success or a negative value on failure. */ -int tsk_mutation_table_keep_rows( - tsk_mutation_table_t *self, const bool *keep, tsk_flags_t options, tsk_id_t *id_map); +int tsk_mutation_table_keep_rows(tsk_mutation_table_t *self, const tsk_bool_t *keep, + tsk_flags_t options, tsk_id_t *id_map); /** @brief Returns true if the data in the specified table is identical to the data @@ -3262,6 +3302,13 @@ the mapping between IDs before and after row deletion. For row ``j``, ``id_map[j]`` will contain the new ID for row ``j`` if it is retained, or :c:macro:`TSK_NULL` if the row has been removed. Thus, ``id_map`` must be an array of at least ``num_rows`` :c:type:`tsk_id_t` values. + +.. warning:: + C++ users need to be careful to specify the correct type when + passing in values for the ``keep`` array, + using ``std::vector`` and not ``std::vector``, + as the latter may not be correct size. + @endrst @param self A pointer to a tsk_population_table_t object. @@ -3273,7 +3320,7 @@ array of at least ``num_rows`` :c:type:`tsk_id_t` values. and old IDs. If NULL, this will be ignored. @return Return 0 on success or a negative value on failure. */ -int tsk_population_table_keep_rows(tsk_population_table_t *self, const bool *keep, +int tsk_population_table_keep_rows(tsk_population_table_t *self, const tsk_bool_t *keep, tsk_flags_t options, tsk_id_t *id_map); /** @@ -3586,6 +3633,13 @@ the mapping between IDs before and after row deletion. For row ``j``, ``id_map[j]`` will contain the new ID for row ``j`` if it is retained, or :c:macro:`TSK_NULL` if the row has been removed. Thus, ``id_map`` must be an array of at least ``num_rows`` :c:type:`tsk_id_t` values. + +.. warning:: + C++ users need to be careful to specify the correct type when + passing in values for the ``keep`` array, + using ``std::vector`` and not ``std::vector``, + as the latter may not be correct size. + @endrst @param self A pointer to a tsk_provenance_table_t object. 
@@ -3597,7 +3651,7 @@ array of at least ``num_rows`` :c:type:`tsk_id_t` values. and old IDs. If NULL, this will be ignored. @return Return 0 on success or a negative value on failure. */ -int tsk_provenance_table_keep_rows(tsk_provenance_table_t *self, const bool *keep, +int tsk_provenance_table_keep_rows(tsk_provenance_table_t *self, const tsk_bool_t *keep, tsk_flags_t options, tsk_id_t *id_map); /** diff --git a/docs/_config.yml b/docs/_config.yml index 6906a678cd..e9ced63c29 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -67,6 +67,7 @@ sphinx: # Note we have to use the regex version here because of # https://github.com/sphinx-doc/sphinx/issues/9748 nitpick_ignore_regex: [ + ["c:identifier", "uint8_t"], ["c:identifier", "int32_t"], ["c:identifier", "uint32_t"], ["c:identifier", "uint64_t"], diff --git a/docs/c-api.rst b/docs/c-api.rst index 33246cf6cd..bd8233ed6e 100644 --- a/docs/c-api.rst +++ b/docs/c-api.rst @@ -233,6 +233,7 @@ Basic Types .. doxygentypedef:: tsk_id_t .. doxygentypedef:: tsk_size_t .. doxygentypedef:: tsk_flags_t +.. doxygentypedef:: tsk_bool_t ************** Common options diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index aaf8e25406..2b379ff2b1 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -968,11 +968,6 @@ int32_array_converter(PyObject *py_obj, PyArrayObject **array_out) static int bool_array_converter(PyObject *py_obj, PyArrayObject **array_out) { - /* We are assuming that npy_bool and C99 bool are interchangeable, which - * may not always be true. If this ever crops up in the real world we - * may want to promote this to a module-load time check. - */ - assert(sizeof(npy_bool) == sizeof(bool)); return array_converter(NPY_BOOL, py_obj, array_out); } @@ -982,7 +977,7 @@ bool_array_converter(PyObject *py_obj, PyArrayObject **array_out) * wraps the library function and casts to the correct table type. 
*/ typedef int keep_row_func_t( - void *self, const bool *keep, tsk_flags_t options, tsk_id_t *id_map); + void *self, const tsk_bool_t *keep, tsk_flags_t options, tsk_id_t *id_map); static PyObject * table_keep_rows( @@ -1398,7 +1393,7 @@ IndividualTable_extend(IndividualTable *self, PyObject *args, PyObject *kwds) static int individual_table_keep_rows_generic( - void *table, const bool *keep, tsk_flags_t options, tsk_id_t *id_map) + void *table, const tsk_bool_t *keep, tsk_flags_t options, tsk_id_t *id_map) { return tsk_individual_table_keep_rows( (tsk_individual_table_t *) table, keep, options, id_map); @@ -2003,7 +1998,7 @@ NodeTable_extend(NodeTable *self, PyObject *args, PyObject *kwds) static int node_table_keep_rows_generic( - void *table, const bool *keep, tsk_flags_t options, tsk_id_t *id_map) + void *table, const tsk_bool_t *keep, tsk_flags_t options, tsk_id_t *id_map) { return tsk_node_table_keep_rows((tsk_node_table_t *) table, keep, options, id_map); } @@ -2599,7 +2594,7 @@ EdgeTable_extend(EdgeTable *self, PyObject *args, PyObject *kwds) static int edge_table_keep_rows_generic( - void *table, const bool *keep, tsk_flags_t options, tsk_id_t *id_map) + void *table, const tsk_bool_t *keep, tsk_flags_t options, tsk_id_t *id_map) { return tsk_edge_table_keep_rows((tsk_edge_table_t *) table, keep, options, id_map); } @@ -3180,7 +3175,7 @@ MigrationTable_extend(MigrationTable *self, PyObject *args, PyObject *kwds) static int migration_table_keep_rows_generic( - void *table, const bool *keep, tsk_flags_t options, tsk_id_t *id_map) + void *table, const tsk_bool_t *keep, tsk_flags_t options, tsk_id_t *id_map) { return tsk_migration_table_keep_rows( (tsk_migration_table_t *) table, keep, options, id_map); @@ -3790,7 +3785,7 @@ SiteTable_extend(SiteTable *self, PyObject *args, PyObject *kwds) static int site_table_keep_rows_generic( - void *table, const bool *keep, tsk_flags_t options, tsk_id_t *id_map) + void *table, const tsk_bool_t *keep, tsk_flags_t options, 
tsk_id_t *id_map) { return tsk_site_table_keep_rows((tsk_site_table_t *) table, keep, options, id_map); } @@ -4365,7 +4360,7 @@ MutationTable_extend(MutationTable *self, PyObject *args, PyObject *kwds) static int mutation_table_keep_rows_generic( - void *table, const bool *keep, tsk_flags_t options, tsk_id_t *id_map) + void *table, const tsk_bool_t *keep, tsk_flags_t options, tsk_id_t *id_map) { return tsk_mutation_table_keep_rows( (tsk_mutation_table_t *) table, keep, options, id_map); @@ -4972,7 +4967,7 @@ PopulationTable_extend(PopulationTable *self, PyObject *args, PyObject *kwds) static int population_table_keep_rows_generic( - void *table, const bool *keep, tsk_flags_t options, tsk_id_t *id_map) + void *table, const tsk_bool_t *keep, tsk_flags_t options, tsk_id_t *id_map) { return tsk_population_table_keep_rows( (tsk_population_table_t *) table, keep, options, id_map); @@ -5476,7 +5471,7 @@ ProvenanceTable_extend(ProvenanceTable *self, PyObject *args, PyObject *kwds) static int provenance_table_keep_rows_generic( - void *table, const bool *keep, tsk_flags_t options, tsk_id_t *id_map) + void *table, const tsk_bool_t *keep, tsk_flags_t options, tsk_id_t *id_map) { return tsk_provenance_table_keep_rows( (tsk_provenance_table_t *) table, keep, options, id_map); From cd0c0ac61722b2f1306edba4b1d253d95a4702a4 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Mon, 31 Oct 2022 12:24:48 -0700 Subject: [PATCH 38/84] Use recursive updates for weights, add weighted quantile method --- python/tests/test_coaltime_distribution.py | 378 +++++++++++++---- python/tskit/stats.py | 467 +++++++++++++++------ 2 files changed, 631 insertions(+), 214 deletions(-) diff --git a/python/tests/test_coaltime_distribution.py b/python/tests/test_coaltime_distribution.py index 30ade948a1..715677d99e 100644 --- a/python/tests/test_coaltime_distribution.py +++ b/python/tests/test_coaltime_distribution.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 
2018-2023 Tskit Developers # Copyright (C) 2016 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -25,6 +25,7 @@ """ import msprime import numpy as np +import pytest import tests import tskit @@ -181,7 +182,7 @@ def ts_two_trees_ten_leaves(self): @tests.cached_example def ts_many_edge_diffs(self): ts = msprime.sim_ancestry( - samples=75, + samples=80, ploidy=1, sequence_length=4, recombination_rate=10, @@ -216,31 +217,31 @@ def test_time(self): t = np.array([0, 1, 5, 8, 29]) distr = self.coalescence_time_distribution() tt = distr.tables[0].time - assert np.allclose(t, tt) + np.testing.assert_allclose(t, tt) def test_block(self): b = np.array([0, 0, 0, 0, 0]) distr = self.coalescence_time_distribution() tb = distr.tables[0].block - assert np.allclose(b, tb) + np.testing.assert_allclose(b, tb) def test_weights(self): w = np.array([[0, 1, 1, 1, 1]]).T distr = self.coalescence_time_distribution() tw = distr.tables[0].weights - assert np.allclose(w, tw) + np.testing.assert_allclose(w, tw) def test_cum_weights(self): c = np.array([[0, 1, 2, 3, 4]]).T distr = self.coalescence_time_distribution() tc = distr.tables[0].cum_weights - assert np.allclose(c, tc) + np.testing.assert_allclose(c, tc) def test_quantile(self): q = np.array([[0, 0.25, 0.50, 0.75, 1]]).T distr = self.coalescence_time_distribution() tq = distr.tables[0].quantile - assert np.allclose(q, tq) + np.testing.assert_allclose(q, tq) class TestPairWeightedCoalescenceTimeTable(TestCoalescenceTimeDistribution): @@ -277,13 +278,13 @@ def test_time(self): t = np.array([0, 1, 5, 8, 29]) distr = self.coalescence_time_distribution() tt = distr.tables[0].time - assert np.allclose(t, tt) + np.testing.assert_allclose(t, tt) def test_block(self): b = np.array([0, 0, 0, 0, 0]) distr = self.coalescence_time_distribution() tb = distr.tables[0].block - assert np.allclose(b, tb) + np.testing.assert_allclose(b, tb) def test_weights(self): w = np.array( @@ -297,7 +298,7 @@ 
def test_weights(self): ) distr = self.coalescence_time_distribution() tw = distr.tables[0].weights - assert np.allclose(w, tw) + np.testing.assert_allclose(w, tw) def test_cum_weights(self): c = np.array( @@ -311,7 +312,7 @@ def test_cum_weights(self): ) distr = self.coalescence_time_distribution() tc = distr.tables[0].cum_weights - assert np.allclose(c, tc) + np.testing.assert_allclose(c, tc) def test_quantile(self): q = np.array( @@ -325,7 +326,7 @@ def test_quantile(self): ) distr = self.coalescence_time_distribution() tq = distr.tables[0].quantile - assert np.allclose(q, tq) + np.testing.assert_allclose(q, tq) class TestTrioFirstWeightedCoalescenceTimeTable(TestCoalescenceTimeDistribution): @@ -376,13 +377,13 @@ def test_time(self): t = np.array([0.0, 1.0, 2.0, 2.0, 6.0, 8.00]) distr = self.coalescence_time_distribution() tt = distr.tables[0].time - assert np.allclose(t, tt) + np.testing.assert_allclose(t, tt) def test_block(self): b = np.array([0, 0, 0, 0, 0, 0]) distr = self.coalescence_time_distribution() tb = distr.tables[0].block - assert np.allclose(b, tb) + np.testing.assert_allclose(b, tb) def test_weights(self): w = np.array( @@ -397,7 +398,7 @@ def test_weights(self): ) distr = self.coalescence_time_distribution() tw = distr.tables[0].weights - assert np.allclose(w, tw) + np.testing.assert_allclose(w, tw) def test_cum_weights(self): c = np.array( @@ -412,7 +413,7 @@ def test_cum_weights(self): ) distr = self.coalescence_time_distribution() tc = distr.tables[0].cum_weights - assert np.allclose(c, tc) + np.testing.assert_allclose(c, tc) def test_quantile(self): q = np.array( @@ -429,7 +430,7 @@ def test_quantile(self): q /= q[-1, :] distr = self.coalescence_time_distribution() tq = distr.tables[0].quantile - assert np.allclose(q, tq[:, :-1]) and np.all(np.isnan(tq[:, -1])) + np.testing.assert_allclose(q, tq[:, :-1]) and np.all(np.isnan(tq[:, -1])) class TestSingleBlockCoalescenceTimeTable(TestCoalescenceTimeDistribution): @@ -461,31 +462,32 @@ def 
test_time(self): t = np.array([0.0, 0.54, 0.59, 0.73, 1.74]) distr = self.coalescence_time_distribution() tt = distr.tables[0].time - assert np.allclose(t, tt) + np.testing.assert_allclose(t, tt) def test_block(self): b = np.array([0, 0, 0, 0, 0]) distr = self.coalescence_time_distribution() tb = distr.tables[0].block - assert np.allclose(b, tb) + np.testing.assert_allclose(b, tb) def test_weights(self): w = np.array([[0, 1, 2, 1, 2]]).T distr = self.coalescence_time_distribution() tw = distr.tables[0].weights - assert np.allclose(w, tw) + np.testing.assert_allclose(w, tw) def test_cum_weights(self): c = np.array([[0, 1, 3, 4, 6]]).T distr = self.coalescence_time_distribution() tc = distr.tables[0].cum_weights - assert np.allclose(c, tc) and np.allclose(c, tc) + np.testing.assert_allclose(c, tc) + np.testing.assert_allclose(c, tc) def test_quantile(self): q = np.array([[0.0, 1 / 6, 3 / 6, 4 / 6, 1.0]]).T distr = self.coalescence_time_distribution() tq = distr.tables[0].quantile - assert np.allclose(q, tq) + np.testing.assert_allclose(q, tq) class TestWindowedCoalescenceTimeTable(TestCoalescenceTimeDistribution): @@ -523,7 +525,8 @@ def test_time(self): distr = self.coalescence_time_distribution() tt1 = distr.tables[0].time tt2 = distr.tables[1].time - assert np.allclose(t1, tt1) and np.allclose(t2, tt2) + np.testing.assert_allclose(t1, tt1) + np.testing.assert_allclose(t2, tt2) def test_block(self): b1 = np.array([0, 0, 0, 0]) @@ -531,7 +534,8 @@ def test_block(self): distr = self.coalescence_time_distribution() tb1 = distr.tables[0].block tb2 = distr.tables[1].block - assert np.allclose(b1, tb1) and np.allclose(b2, tb2) + np.testing.assert_allclose(b1, tb1) + np.testing.assert_allclose(b2, tb2) def test_weights(self): w1 = np.array([[0, 1, 1, 1]]).T @@ -539,7 +543,8 @@ def test_weights(self): distr = self.coalescence_time_distribution() tw1 = distr.tables[0].weights tw2 = distr.tables[1].weights - assert np.allclose(w1, tw1) and np.allclose(w2, tw2) + 
np.testing.assert_allclose(w1, tw1) + np.testing.assert_allclose(w2, tw2) def test_cum_weights(self): c1 = np.array([[0, 1, 2, 3]]).T @@ -547,7 +552,8 @@ def test_cum_weights(self): distr = self.coalescence_time_distribution() tc1 = distr.tables[0].cum_weights tc2 = distr.tables[1].cum_weights - assert np.allclose(c1, tc1) and np.allclose(c2, tc2) + np.testing.assert_allclose(c1, tc1) + np.testing.assert_allclose(c2, tc2) def test_quantile(self): e1 = np.array([[0.0, 1 / 3, 2 / 3, 1.0]]).T @@ -555,7 +561,8 @@ def test_quantile(self): distr = self.coalescence_time_distribution() te1 = distr.tables[0].quantile te2 = distr.tables[1].quantile - assert np.allclose(e1, te1) and np.allclose(e2, te2) + np.testing.assert_allclose(e1, te1) + np.testing.assert_allclose(e2, te2) class TestCoalescenceTimeDistributionPointMethods(TestCoalescenceTimeDistribution): @@ -595,7 +602,7 @@ def test_ecdf(self): [0.0, 0.25, et[1], 0.57, et[2], 0.65, et[3], 1.00, et[4], 2.00], ) te = distr.ecdf(t) - assert np.allclose(e, te) + np.testing.assert_allclose(e, te) def test_num_coalesced(self): c = np.array([0, 0, 1, 1, 3, 3, 4, 4, 6, 6]).reshape(1, 10, 1) @@ -605,7 +612,7 @@ def test_num_coalesced(self): [0.0, 0.25, et[1], 0.57, et[2], 0.65, et[3], 1.00, et[4], 2.00], ) tc = distr.num_coalesced(t) - assert np.allclose(c, tc) + np.testing.assert_allclose(c, tc) def test_num_uncoalesced(self): u = np.array([6, 6, 5, 5, 3, 3, 2, 2, 0, 0]).reshape(1, 10, 1) @@ -615,7 +622,28 @@ def test_num_uncoalesced(self): [0.0, 0.25, et[1], 0.57, et[2], 0.65, et[3], 1.00, et[4], 2.00], ) tu = distr.num_uncoalesced(t) - assert np.allclose(u, tu) + np.testing.assert_allclose(u, tu) + + def test_interpolated_quantile(self): + x = np.array( + [ + 0.54, + 0.558, + 0.576, + 0.5993, + 0.6413, + 0.6833, + 0.7253, + 0.9609, + 1.2206, + 1.4803, + 1.74, + ] + ).reshape(1, 11, 1) + distr = self.coalescence_time_distribution() + q = np.linspace(0, 1, 11) + qx = distr.quantile(q).round(4) + np.testing.assert_allclose(x, 
qx) class TestCoalescenceTimeDistributionIntervalMethods(TestCoalescenceTimeDistribution): @@ -667,7 +695,7 @@ def test_coalescence_probability_in_intervals(self): et = distr.tables[0].time t = np.array([0.00, 0.55, et[3], 2.00]) tp = distr.coalescence_probability_in_intervals(t) - assert np.allclose(p, tp) + np.testing.assert_allclose(p, tp) def test_coalescence_probability_in_intervals_oor(self): distr = self.coalescence_time_distribution() @@ -681,7 +709,7 @@ def test_coalescence_rate_in_intervals(self): et = distr.tables[0].time t = np.array([0.00, 0.55, et[3], 2.00]) tc = distr.coalescence_rate_in_intervals(t) - assert np.allclose(c, tc) + np.testing.assert_allclose(c, tc, atol=1e-6) def test_coalescence_rate_in_intervals_oor(self): distr = self.coalescence_time_distribution() @@ -694,7 +722,7 @@ def test_mean(self): distr = self.coalescence_time_distribution() et = distr.tables[0].time tm = distr.mean(et[2]) - assert np.allclose(m, tm) + np.testing.assert_allclose(m, tm) def test_mean_oor(self): distr = self.coalescence_time_distribution() @@ -754,7 +782,8 @@ def test_cum_weights(self): boot_distr = self.coalescence_time_distribution_boot() tw1 = boot_distr.tables[0].cum_weights tw2 = boot_distr.tables[1].cum_weights - assert np.allclose(w1, tw1) and np.allclose(w2, tw2) + np.testing.assert_allclose(w1, tw1) + np.testing.assert_allclose(w2, tw2) def test_ecdf(self): e = np.array( @@ -766,20 +795,20 @@ def test_ecdf(self): boot_distr = self.coalescence_time_distribution_boot() t = np.array([0.54, 0.55, 0.59, 0.60, 0.73, 0.74, 1.74]) te = boot_distr.ecdf(t) - assert np.allclose(e, te) + np.testing.assert_allclose(e, te) def test_mean(self): m = np.array([[1.02, 0.9566667]]) boot_distr = self.coalescence_time_distribution_boot() tm = boot_distr.mean() - assert np.allclose(m, tm) + np.testing.assert_allclose(m, tm) def test_boot_of_boot_equivalence(self): boot_distr = self.coalescence_time_distribution_boot() reboot_distr = next(boot_distr.block_bootstrap(1, 3)) 
cw1 = boot_distr.tables[1].cum_weights cw2 = reboot_distr.tables[1].cum_weights - assert np.allclose(cw1, cw2) + np.testing.assert_allclose(cw1, cw2) class TestCoalescenceTimeDistributionEmpty(TestCoalescenceTimeDistribution): @@ -790,11 +819,16 @@ class TestCoalescenceTimeDistributionEmpty(TestCoalescenceTimeDistribution): def coalescence_time_distribution(self): ts = self.ts_two_trees_four_leaves() - def null_weight(node, tree, sample_sets): - return np.array([0, 0]) + def null_weight_init(node, sample_sets): + blank = np.array([[0, 0]], dtype=np.float64) + return (blank,) + + def null_weight_update(blank): + blank = np.array([[0, 0]], dtype=np.float64) + return blank, (blank,) distr = ts.coalescence_time_distribution( - weight_func=null_weight, + weight_func=(null_weight_init, null_weight_update), span_normalise=False, ) return distr @@ -834,6 +868,18 @@ def test_coalescence_rate_in_intervals(self): tc = distr.coalescence_rate_in_intervals(t) assert np.all(np.isnan(tc)) + def test_quantile(self): + distr = self.coalescence_time_distribution() + t = np.array([0.0, 0.5, 1.0]) + tq = distr.quantile(t) + assert np.all(np.isnan(tq)) + + def test_resample(self): + distr = self.coalescence_time_distribution() + boot_distr = next(distr.block_bootstrap(1, 3)) + assert np.all(boot_distr.tables[0].cum_weights == 0) + assert np.all(np.isnan(boot_distr.tables[0].quantile)) + class TestCoalescenceTimeDistributionNullWeight(TestCoalescenceTimeDistribution): """ @@ -844,11 +890,16 @@ class TestCoalescenceTimeDistributionNullWeight(TestCoalescenceTimeDistribution) def coalescence_time_distribution(self): ts = self.ts_two_trees_four_leaves() - def half_empty(node, tree, sample_sets): - return np.array([1, 0]) + def half_empty_init(node, sample_sets): + blank = np.array([[1, 0]], dtype=np.float64) + return (blank,) + + def half_empty_update(blank): + blank = np.array([[1, 0]], dtype=np.float64) + return blank, (blank,) distr = ts.coalescence_time_distribution( - 
weight_func=half_empty, + weight_func=(half_empty_init, half_empty_update), span_normalise=False, ) return distr @@ -888,6 +939,20 @@ def test_coalescence_rate_in_intervals(self): tr = distr.coalescence_rate_in_intervals(t) assert np.all(np.isnan(tr[1, :])) and np.all(~np.isnan(tr[0, :])) + def test_quantile(self): + distr = self.coalescence_time_distribution() + t = np.array([0.0, 0.5, 1.0]) + tq = distr.quantile(t) + assert np.all(np.isnan(tq[1, :])) and np.all(~np.isnan(tq[0, :])) + + def test_resample(self): + distr = self.coalescence_time_distribution() + boot_distr = next(distr.block_bootstrap(1, 3)) + assert np.all(boot_distr.tables[0].cum_weights[:, 1] == 0) + assert np.all(np.isnan(boot_distr.tables[0].quantile[:, 1])) + assert np.any(boot_distr.tables[0].cum_weights[:, 0] > 0) + assert np.all(~np.isnan(boot_distr.tables[0].quantile[:, 0])) + class TestCoalescenceTimeDistributionTableResize(TestCoalescenceTimeDistribution): """ @@ -921,12 +986,18 @@ def coalescence_time_distribution(self): ts = self.ts_eight_trees_two_leaves() bk = [t.interval.left for t in ts.trees()][::4] + [ts.sequence_length] - def count_root(node, tree, sample_sets): - weight = int(node == tree.get_root()) - return np.array([weight]) + def count_root_init(node, sample_sets): + all_samples = [i for s in sample_sets for i in s] + state = np.array([[node == i for i in all_samples]], dtype=np.float64) + return (state,) + + def count_root_update(child_state): + state = np.sum(child_state, axis=0, keepdims=True) + is_root = np.array([[np.all(state > 0)]], dtype=np.float64) + return is_root, (state,) distr = ts.coalescence_time_distribution( - weight_func=count_root, + weight_func=(count_root_init, count_root_update), window_breaks=np.array(bk), blocks_per_window=2, span_normalise=False, @@ -936,12 +1007,12 @@ def count_root(node, tree, sample_sets): def test_blocks_per_window(self): distr = self.coalescence_time_distribution() bpw = np.array([i.num_blocks for i in distr.tables]) - assert 
np.allclose(bpw, 2) + np.testing.assert_allclose(bpw, 2) def test_trees_per_window(self): distr = self.coalescence_time_distribution() tpw = np.array([np.sum(distr.tables[i].weights) for i in range(2)]) - assert np.allclose(tpw, 4) + np.testing.assert_allclose(tpw, 4) def test_trees_per_block(self): distr = self.coalescence_time_distribution() @@ -949,7 +1020,80 @@ def test_trees_per_block(self): for table in distr.tables: for block in range(2): tpb += [np.sum(table.weights[table.block == block])] - assert np.allclose(tpb, 2) + np.testing.assert_allclose(tpb, 2) + + +class TestCoalescenceTimeDistributionBlockedVsUnblocked( + TestCoalescenceTimeDistribution +): + """ + Test that methods give the same result regardless of how trees are blocked. + """ + + def coalescence_time_distribution(self, num_blocks=1): + ts = self.ts_many_edge_diffs() + sample_sets = [list(range(10)), list(range(20, 40)), list(range(70, 80))] + distr = ts.coalescence_time_distribution( + sample_sets=sample_sets, + weight_func="pair_coalescence_events", + blocks_per_window=num_blocks, + span_normalise=True, + ) + return distr + + def test_ecdf(self): + distr_noblock = self.coalescence_time_distribution(num_blocks=1) + distr_block = self.coalescence_time_distribution(num_blocks=10) + t = np.linspace(0, distr_noblock.tables[0].time[-1] + 1, 5) + np.testing.assert_allclose(distr_noblock.ecdf(t), distr_block.ecdf(t)) + + def test_num_coalesced(self): + distr_noblock = self.coalescence_time_distribution(num_blocks=1) + distr_block = self.coalescence_time_distribution(num_blocks=10) + t = np.linspace(0, distr_noblock.tables[0].time[-1] + 1, 5) + np.testing.assert_allclose( + distr_noblock.num_coalesced(t), distr_block.num_coalesced(t) + ) + + def test_num_uncoalesced(self): + distr_noblock = self.coalescence_time_distribution(num_blocks=1) + distr_block = self.coalescence_time_distribution(num_blocks=10) + t = np.linspace(0, distr_noblock.tables[0].time[-1] + 1, 5) + np.testing.assert_allclose( + 
distr_noblock.num_uncoalesced(t), distr_block.num_uncoalesced(t) + ) + + def test_quantile(self): + distr_noblock = self.coalescence_time_distribution(num_blocks=1) + distr_block = self.coalescence_time_distribution(num_blocks=10) + q = np.linspace(0, 1, 11) + np.testing.assert_allclose(distr_noblock.quantile(q), distr_block.quantile(q)) + + def test_mean(self): + distr_noblock = self.coalescence_time_distribution(num_blocks=1) + distr_block = self.coalescence_time_distribution(num_blocks=10) + t = distr_noblock.tables[0].time[-1] / 2 + np.testing.assert_allclose( + distr_noblock.mean(since=t), distr_block.mean(since=t) + ) + + def test_coalescence_rate_in_intervals(self): + distr_noblock = self.coalescence_time_distribution(num_blocks=1) + distr_block = self.coalescence_time_distribution(num_blocks=10) + t = np.linspace(0, distr_noblock.tables[0].time[-1] + 1, 5) + np.testing.assert_allclose( + distr_noblock.coalescence_rate_in_intervals(t), + distr_block.coalescence_rate_in_intervals(t), + ) + + def test_coalescence_probability_in_intervals(self): + distr_noblock = self.coalescence_time_distribution(num_blocks=1) + distr_block = self.coalescence_time_distribution(num_blocks=10) + t = np.linspace(0, distr_noblock.tables[0].time[-1] + 1, 5) + np.testing.assert_allclose( + distr_noblock.coalescence_probability_in_intervals(t), + distr_block.coalescence_probability_in_intervals(t), + ) class TestCoalescenceTimeDistributionRunningUpdate(TestCoalescenceTimeDistribution): @@ -957,23 +1101,23 @@ class TestCoalescenceTimeDistributionRunningUpdate(TestCoalescenceTimeDistributi When traversing trees, weights are updated for nodes whose descendant subtree has changed. This is done by taking the parents of added edges, and tracing ancestors down to the root. This class tests that this "running update" - scheme produces the same results as calculating weights separately for each - tree. + scheme produces the correct result. 
""" - # TODO: when missing data handling is implemented, test here - - def coalescence_time_distribution(self, ts): - brk = np.array([t.interval.left for t in ts.trees()] + [ts.sequence_length]) - smp_set = np.arange(0, ts.num_samples) - smp_set = np.floor_divide((len(brk) - 1) * smp_set, ts.num_samples) - smp_set = [np.where(smp_set == i)[0].tolist() for i in range(len(brk) - 1)] + def coalescence_time_distribution_running(self, ts, brk, sets=2): + n = ts.num_samples // sets + smp_set = [list(range(i, i + n)) for i in range(0, ts.num_samples, n)] distr = ts.coalescence_time_distribution( sample_sets=smp_set, window_breaks=brk, weight_func="trio_first_coalescence_events", span_normalise=False, ) + return distr + + def coalescence_time_distribution_split(self, ts, brk, sets=2): + n = ts.num_samples // sets + smp_set = [list(range(i, i + n)) for i in range(0, ts.num_samples, n)] distr_by_win = [] for left, right in zip(brk[:-1], brk[1:]): ts_trim = ts.keep_intervals([[left, right]]).trim() @@ -981,20 +1125,65 @@ def coalescence_time_distribution(self, ts): ts_trim.coalescence_time_distribution( sample_sets=smp_set, weight_func="trio_first_coalescence_events", + span_normalise=False, ) ] - return distr, distr_by_win + return distr_by_win def test_many_edge_diffs(self): + """ + Test that ts windowed by tree gives same result as set of single trees. 
+ """ ts = self.ts_many_edge_diffs() - distr, distr_win = self.coalescence_time_distribution(ts) + brk = np.array([t.interval.left for t in ts.trees()] + [ts.sequence_length]) + distr = self.coalescence_time_distribution_running(ts, brk) + distr_win = self.coalescence_time_distribution_split(ts, brk) time_breaks = np.array([np.inf]) updt = distr.num_coalesced(time_breaks) sepr = np.zeros(updt.shape) for i, d in enumerate(distr_win): c = d.num_coalesced(time_breaks) sepr[:, :, i] = c.reshape((c.shape[0], 1)) - assert np.allclose(sepr, updt) + np.testing.assert_allclose(sepr, updt) + + def test_missing_trees(self): + """ + Test that ts with half of each tree masked gives same result as unmasked ts. + """ + ts = self.ts_many_edge_diffs() + brk = np.array([t.interval.left for t in ts.trees()] + [ts.sequence_length]) + mask = np.array( + [ + [tr.interval.left, (tr.interval.right + tr.interval.left) / 2] + for tr in ts.trees() + ] + ) + ts_mask = ts.delete_intervals(mask) + distr = self.coalescence_time_distribution_running(ts, brk) + distr_mask = self.coalescence_time_distribution_running(ts_mask, brk) + time_breaks = np.array([np.inf]) + updt = distr.num_coalesced(time_breaks) + updt_mask = distr_mask.num_coalesced(time_breaks) + np.testing.assert_allclose(updt, updt_mask) + + def test_unary_nodes(self): + """ + Test that ts with unary nodes gives same result as ts with unary nodes removed. 
+ """ + ts = self.ts_many_edge_diffs() + ts_unary = ts.simplify( + samples=list(range(ts.num_samples // 2)), keep_unary=True + ) + ts_nounary = ts.simplify( + samples=list(range(ts.num_samples // 2)), keep_unary=False + ) + brk = np.array([t.interval.left for t in ts.trees()] + [ts.sequence_length]) + distr_unary = self.coalescence_time_distribution_running(ts_unary, brk) + distr_nounary = self.coalescence_time_distribution_running(ts_nounary, brk) + time_breaks = np.array([np.inf]) + updt_unary = distr_unary.num_coalesced(time_breaks) + updt_nounary = distr_nounary.num_coalesced(time_breaks) + np.testing.assert_allclose(updt_unary, updt_nounary) class TestSpanNormalisedCoalescenceTimeTable(TestCoalescenceTimeDistribution): @@ -1016,24 +1205,39 @@ class TestSpanNormalisedCoalescenceTimeTable(TestCoalescenceTimeDistribution): Uniform weights on nodes summed over trees, weighted by tree span """ - def coalescence_time_distribution(self): + def coalescence_time_distribution(self, mask_half_of_each_tree=False): + """ + Methods should give the same result if half of each tree is masked, + because "span weights" are normalised using the accessible (nonmissing) + portion of the tree sequence. 
+ """ ts = self.ts_two_trees_four_leaves() + if mask_half_of_each_tree: + mask = np.array( + [ + [t.interval.left, (t.interval.right + t.interval.left) / 2] + for t in ts.trees() + ] + ) + ts = ts.delete_intervals(mask) distr = ts.coalescence_time_distribution( span_normalise=True, ) return distr - def test_weights(self): + @pytest.mark.parametrize("with_missing_data", [True, False]) + def test_weights(self, with_missing_data): w = np.array([[0, 0.12, 1.0, 0.88, 1.0]]).T - distr = self.coalescence_time_distribution() + distr = self.coalescence_time_distribution(with_missing_data) tw = distr.tables[0].weights - assert np.allclose(w, tw) + np.testing.assert_allclose(w, tw) - def test_cum_weights(self): + @pytest.mark.parametrize("with_missing_data", [True, False]) + def test_cum_weights(self, with_missing_data): c = np.array([[0, 0.12, 1.12, 2.00, 3.00]]).T - distr = self.coalescence_time_distribution() + distr = self.coalescence_time_distribution(with_missing_data) tc = distr.tables[0].cum_weights - assert np.allclose(c, tc) and np.allclose(c, tc) + np.testing.assert_allclose(c, tc) class TestWindowedSpanNormalisedCoalescenceTimeTable(TestCoalescenceTimeDistribution): @@ -1058,9 +1262,19 @@ class TestWindowedSpanNormalisedCoalescenceTimeTable(TestCoalescenceTimeDistribu """ @tests.cached_example - def coalescence_time_distribution(self): + def coalescence_time_distribution(self, mask_half_of_each_tree=False): + """ + Methods should give the same result if half of each tree is masked, + because "span weights" are normalised using the accessible (nonmissing) + portion of the tree sequence. 
+ """ ts = self.ts_two_trees_four_leaves() gen_breaks = np.array([0.0, 0.5, 1.0]) + if mask_half_of_each_tree: + breaks = [i for i in ts.breakpoints()] + breaks = np.unique(np.concatenate([breaks, gen_breaks])) + mask = np.array([[a, (a + b) / 2] for a, b in zip(breaks[:-1], breaks[1:])]) + ts = ts.keep_intervals(mask) distr = ts.coalescence_time_distribution( window_breaks=gen_breaks, blocks_per_window=2, @@ -1068,26 +1282,32 @@ def coalescence_time_distribution(self): ) return distr - def test_time(self): + @pytest.mark.parametrize("with_missing_data", [True, False]) + def test_time(self, with_missing_data): t1 = np.array([0.0, 0.59, 0.73, 1.74]) t2 = np.array([0.0, 0.54, 0.59, 0.59, 0.73, 1.74, 1.74]) - distr = self.coalescence_time_distribution() + distr = self.coalescence_time_distribution(with_missing_data) tt1 = distr.tables[0].time tt2 = distr.tables[1].time - assert np.allclose(t1, tt1) and np.allclose(t2, tt2) + np.testing.assert_allclose(t1, tt1) + np.testing.assert_allclose(t2, tt2) - def test_block(self): + @pytest.mark.parametrize("with_missing_data", [True, False]) + def test_block(self, with_missing_data): b1 = np.array([0, 0, 0, 0]) b2 = np.array([0, 1, 0, 1, 0, 0, 1]) - distr = self.coalescence_time_distribution() + distr = self.coalescence_time_distribution(with_missing_data) tb1 = distr.tables[0].block tb2 = distr.tables[1].block - assert np.allclose(b1, tb1) and np.allclose(b2, tb2) + np.testing.assert_allclose(b1, tb1) + np.testing.assert_allclose(b2, tb2) - def test_weights(self): + @pytest.mark.parametrize("with_missing_data", [True, False]) + def test_weights(self, with_missing_data): w1 = np.array([[0, 1.0, 1.0, 1.0]]).T w2 = np.array([[0, 0.24, 0.76, 0.24, 0.76, 0.76, 0.24]]).T - distr = self.coalescence_time_distribution() + distr = self.coalescence_time_distribution(with_missing_data) tw1 = distr.tables[0].weights tw2 = distr.tables[1].weights - assert np.allclose(w1, tw1) and np.allclose(w2, tw2) + np.testing.assert_allclose(w1, tw1) + 
np.testing.assert_allclose(w2, tw2) diff --git a/python/tskit/stats.py b/python/tskit/stats.py index 3e161662b1..a972b01454 100644 --- a/python/tskit/stats.py +++ b/python/tskit/stats.py @@ -215,24 +215,49 @@ def resample_blocks(self, block_multiplier): ) if self.cum_weights[-1, i] > 0: self.quantile[:, i] = self.cum_weights[:, i] / self.cum_weights[-1, i] + else: + self.quantile[:, i] = np.nan class CoalescenceTimeDistribution: """ Class to precompute a table of sorted/weighted node times, from which to calculate the empirical distribution function and estimate coalescence rates in time windows. + + To compute weights efficiently requires an update operation of the form: + + ``output[parent], state[parent] = update(state[children])`` + + where ``output`` are the weights associated with the node, and ``state`` + are values that are needed to compute ``output`` that are recursively + calculated along the tree. The value of ``state`` on the leaves is + initialized via, + + ``state[sample] = initialize(sample, sample_sets)`` """ @staticmethod - def _count_coalescence_events(node, tree, sample_sets): - # TODO this will count unary nodes: should it count nodes - # with >1 child instead? - return np.array([1], dtype=np.int32) + def _count_coalescence_events(): + """ + Count the number of samples that coalesce in ``node``, within each + set of samples in ``sample_sets``. 
+ """ + + def initialize(node, sample_sets): + singles = np.array([[node in s for s in sample_sets]], dtype=np.float64) + return (singles,) + + def update(singles_per_child): + singles = np.sum(singles_per_child, axis=0, keepdims=True) + is_ancestor = (singles > 0).astype(np.float64) + return is_ancestor, (singles,) + + return (initialize, update) @staticmethod - def _count_pair_coalescence_events(node, tree, sample_sets): + def _count_pair_coalescence_events(): """ - Count the number of pairs that coalesce in node, within and between the + Count the number of pairs that coalesce in ``node``, within and between the sets of samples in ``sample_sets``. The count of pairs with members that belong to sets :math:`a` and :math:`b` is: @@ -246,30 +271,29 @@ def _count_pair_coalescence_events(node, tree, sample_sets): correspond to counts of pairs with set labels ``[(0,0), (0,1), (1,1)]``. """ - # TODO needs to be optimized, use np.intersect1d - children = tree.children(node) - samples_per_child = [set(list(tree.samples(c))) for c in children] - sample_counts = np.zeros((len(sample_sets), len(children)), dtype=np.int32) - for i, s1 in enumerate(samples_per_child): - for a, s2 in enumerate([set(s) for s in sample_sets]): - sample_counts[a, i] = len(s1 & s2) - - pair_counts = [] - for a, b in itertools.combinations_with_replacement( - range(sample_counts.shape[0]), 2 - ): - count = 0 - for i, j in itertools.combinations(range(sample_counts.shape[1]), 2): - count += ( - sample_counts[a, i] * sample_counts[b, j] - + sample_counts[a, j] * sample_counts[b, i] - ) / (1 + int(a == b)) - pair_counts.append(count) - - return np.array(pair_counts, dtype=np.int32) + def initialize(node, sample_sets): + singles = np.array([[node in s for s in sample_sets]], dtype=np.float64) + return (singles,) + + def update(singles_per_child): + C = singles_per_child.shape[0] # number of children + S = singles_per_child.shape[1] # number of sample sets + singles = np.sum(singles_per_child, axis=0, 
keepdims=True) + pairs = np.zeros((1, int(S * (S + 1) / 2)), dtype=np.float64) + for a, b in itertools.combinations(range(C), 2): + for i, (j, k) in enumerate( + itertools.combinations_with_replacement(range(S), 2) + ): + pairs[0, i] += ( + singles_per_child[a, j] * singles_per_child[b, k] + + singles_per_child[a, k] * singles_per_child[b, j] + ) / (1 + int(j == k)) + return pairs, (singles,) + + return (initialize, update) @staticmethod - def _count_trio_first_coalescence_events(node, tree, sample_sets): + def _count_trio_first_coalescence_events(): """ Count the number of pairs that coalesce in node with an outgroup, within and between the sets of samples in ``sample_sets``. In other @@ -290,88 +314,160 @@ def _count_trio_first_coalescence_events(node, tree, sample_sets): correspond to counts of pairs with set labels, ``[((0,0),0), ((0,0),1), ..., ((0,1),0), ((0,1),1), ...]``. """ - samples = list(tree.samples(node)) - outg_counts = [len(s) - len(np.intersect1d(samples, s)) for s in sample_sets] - pair_counts = CoalescenceTimeDistribution._count_pair_coalescence_events( - node, tree, sample_sets - ) - trio_counts = [] - for i in pair_counts: - for j in outg_counts: - trio_counts.append(i * j) - return np.array(trio_counts, dtype=np.int32) - def _update_weights_by_edge_diff(self, tree, edge_diff, running_weights): + def initialize(node, sample_sets): + S = len(sample_sets) + totals = np.array([[len(s) for s in sample_sets]], dtype=np.float64) + singles = np.array([[node in s for s in sample_sets]], dtype=np.float64) + pairs = np.zeros((1, int(S * (S + 1) / 2)), dtype=np.float64) + return ( + totals, + singles, + pairs, + ) + + def update(totals_per_child, singles_per_child, pairs_per_child): + C = totals_per_child.shape[0] # number of children + S = totals_per_child.shape[1] # number of sample sets + totals = np.mean(totals_per_child, axis=0, keepdims=True) + singles = np.sum(singles_per_child, axis=0, keepdims=True) + pairs = np.zeros((1, int(S * (S + 1) / 2)), 
dtype=np.float64) + for a, b in itertools.combinations(range(C), 2): + pair_iterator = itertools.combinations_with_replacement(range(S), 2) + for i, (j, k) in enumerate(pair_iterator): + pairs[0, i] += ( + singles_per_child[a, j] * singles_per_child[b, k] + + singles_per_child[a, k] * singles_per_child[b, j] + ) / (1 + int(j == k)) + outgr = totals - singles + trios = np.zeros((1, pairs.size * outgr.size), dtype=np.float64) + trio_iterator = itertools.product(range(pairs.size), range(outgr.size)) + for i, (j, k) in enumerate(trio_iterator): + trios[0, i] += pairs[0, j] * outgr[0, k] + return trios, ( + totals, + singles, + pairs, + ) + + return (initialize, update) + + def _update_running_with_edge_diff( + self, tree, edge_diff, running_output, running_state, running_index + ): """ - Update ``running_weights`` to reflect ``tree`` using edge differences - ``edge_diff`` with the previous tree. + Update ``running_output`` and ``running_state`` to reflect ``tree``, + using edge differences ``edge_diff`` with the previous tree. + The dict ``running_index`` maps node IDs onto rows of the running arrays. """ assert edge_diff.interval == tree.interval - # nodes that have been removed from tree - removed = {i.child for i in edge_diff.edges_out if tree.is_isolated(i.child)} - # TODO: What if sample is removed from tree? In that case should all - # nodes be updated for trio first coalescences? 
- - # nodes where descendant subtree has been altered - modified = {i.parent for i in edge_diff.edges_in} - for i in copy.deepcopy(modified): - while tree.parent(i) != tskit.NULL and not tree.parent(i) in modified: + # empty rows in the running arrays + available_rows = {i for i in range(self.running_array_size)} + available_rows -= set(running_index.values()) + + # find internal nodes that have been removed from tree or are unary + removed_nodes = set() + for i in edge_diff.edges_out: + for j in [i.child, i.parent]: + if tree.num_children(j) < 2 and not tree.is_sample(j): + removed_nodes.add(j) + + # find non-unary nodes where descendant subtree has been altered + modified_nodes = { + i.parent for i in edge_diff.edges_in if tree.num_children(i.parent) > 1 + } + for i in copy.deepcopy(modified_nodes): + while tree.parent(i) != tskit.NULL and not tree.parent(i) in modified_nodes: i = tree.parent(i) - modified.add(i) + if tree.num_children(i) > 1: + modified_nodes.add(i) + + # clear running state/output for nodes that are no longer in tree + for i in removed_nodes: + if i in running_index: + running_state[running_index[i], :] = 0 + running_output[running_index[i], :] = 0 + available_rows.add(running_index.pop(i)) + + # recalculate state/output for nodes whose descendants have changed + for i in sorted(modified_nodes, key=lambda node: tree.time(node)): + children = [] + for c in tree.children(i): # skip unary children + while tree.num_children(c) == 1: + (c,) = tree.children(c) + children.append(c) + child_index = [running_index[c] for c in children] + + inputs = ( + running_state[child_index][:, state_index] + for state_index in self.state_indices + ) + output, state = self._update(*inputs) - # recalculate weights for current tree - for i in removed: - running_weights[i, :] = 0 - for i in modified: - running_weights[i, :] = self.weight_func(i, tree, self.sample_sets) - self.weight_func_evals += len(modified) + # update running state/output arrays + if i not in 
running_index: + running_index[i] = available_rows.pop() + running_output[running_index[i], :] = output + for state_index, x in zip(self.state_indices, state): + running_state[running_index[i], state_index] = x + + # track the number of times the weight function was called + self.weight_func_evals += len(modified_nodes) def _build_ecdf_table_for_window( - self, left, right, tree, edge_diffs, running_weights + self, + left, + right, + tree, + edge_diffs, + running_output, + running_state, + running_index, ): """ - Construct ECDF table for genomic interval [left, right]. Update ``tree``, - ``edge_diffs``, and ``running_weights`` for input for next window. Trees are - counted as belonging to any interval with which they overlap, and thus - can be used in several intervals. Thus, the concatenation of ECDF - tables across multiple intervals is not the same as the ECDF table - for the union of those intervals. Trees within intervals are chunked - into roughly equal-sized blocks for bootstrapping. + Construct ECDF table for genomic interval [left, right]. Update + ``tree``; ``edge_diffs``; and ``running_output``, ``running_state``, + `running_idx``; for input for next window. Trees are counted as + belonging to any interval with which they overlap, and thus can be used + in several intervals. Thus, the concatenation of ECDF tables across + multiple intervals is not the same as the ECDF table for the union of + those intervals. Trees within intervals are chunked into roughly + equal-sized blocks for bootstrapping. """ assert tree.interval.left <= left and right > left + # TODO: if bootstrapping, block span needs to be tracked + # and used to renormalise each replicate. This should be + # done by the bootstrapping machinery, not here. + # assign trees in window to equal-sized blocks with unique id - other_tree = tree.copy() - # TODO: is a full copy of the tree needed, given that the original is - # mutated below? 
- if right >= other_tree.tree_sequence.sequence_length: - other_tree.last() - else: - # other_tree.seek(right) won't work if `right` is recomb breakpoint - while other_tree.interval.right < right: - other_tree.next() - tree_idx = np.arange(tree.index, other_tree.index + 1) - tree.index tree_offset = tree.index + if right >= tree.tree_sequence.sequence_length: + tree.last() + else: + # tree.seek(right) won't work if `right` is recomb breakpoint + while tree.interval.right < right: + tree.next() + tree_idx = np.arange(tree_offset, tree.index + 1) - tree_offset num_blocks = min(self.num_blocks, len(tree_idx)) tree_blocks = np.floor_divide(num_blocks * tree_idx, len(tree_idx)) # calculate span weights - # TODO: if bootstrapping, does block span need to be tracked - # and used to renormalise each replicate? - other_tree.seek(tree.interval.left) - tree_span = [ - min(other_tree.interval.right, right) - max(other_tree.interval.left, left) - ] - while other_tree.index < tree_offset + tree_idx[-1]: - other_tree.next() + tree.seek_index(tree_offset) + tree_span = [min(tree.interval.right, right) - max(tree.interval.left, left)] + while tree.index < tree_offset + tree_idx[-1]: + tree.next() tree_span.append( - min(other_tree.interval.right, right) - - max(other_tree.interval.left, left) + min(tree.interval.right, right) - max(tree.interval.left, left) ) - tree_span = np.array(tree_span) / sum(tree_span) + tree_span = np.array(tree_span) + total_span = np.sum(tree_span) + assert np.isclose( + total_span, min(right, tree.tree_sequence.sequence_length) - left + ) # storage if using single window, block for entire tree sequence buffer_size = self.buffer_size @@ -381,49 +477,64 @@ def _build_ecdf_table_for_window( weights = np.zeros((table_size, self.num_weights)) # assemble table of coalescence times in window + num_record = 0 + accessible_span = 0.0 + span_weight = 1.0 indices = np.zeros(tree.tree_sequence.num_nodes, dtype=np.int32) - 1 last_block = 
np.zeros(tree.tree_sequence.num_nodes, dtype=np.int32) - 1 - num_record = 0 + tree.seek_index(tree_offset) while tree.index != tskit.NULL: if tree.interval.right > left: current_block = tree_blocks[tree.index - tree_offset] if self.span_normalise: - span_weight = tree_span[tree.index - tree_offset] - else: - span_weight = 1.0 - nodes_in_tree = np.array( - [i for i in tree.nodes() if tree.is_internal(i)] - ) - # TODO this will fail if all nodes are isolated (masked tree) - nodes_to_add = nodes_in_tree[ - np.where(last_block[nodes_in_tree] != current_block) - ] - if len(nodes_to_add) > 0: - idx = np.arange(num_record, num_record + len(nodes_to_add)) - last_block[nodes_to_add] = current_block - indices[nodes_to_add] = idx - if table_size < num_record + len(nodes_to_add): - table_size += buffer_size - time = np.pad(time, (0, buffer_size)) - block = np.pad(block, (0, buffer_size)) - weights = np.pad(weights, ((0, buffer_size), (0, 0))) - time[idx] = [tree.time(i) for i in nodes_to_add] - block[idx] = current_block - num_record += len(nodes_to_add) - weights[indices[nodes_in_tree], :] += ( - span_weight * running_weights[nodes_in_tree, :] + span_weight = tree_span[tree.index - tree_offset] / total_span + + # TODO: shouldn't need to loop over all keys (nodes) for every tree + internal_nodes = np.array( + [i for i in running_index.keys() if not tree.is_sample(i)], + dtype=np.int32, ) + if internal_nodes.size > 0: + accessible_span += tree_span[tree.index - tree_offset] + rows_in_running = np.array( + [running_index[i] for i in internal_nodes], dtype=np.int32 + ) + nodes_to_add = internal_nodes[ + last_block[internal_nodes] != current_block + ] + if nodes_to_add.size > 0: + table_idx = np.arange( + num_record, num_record + len(nodes_to_add) + ) + last_block[nodes_to_add] = current_block + indices[nodes_to_add] = table_idx + if table_size < num_record + len(nodes_to_add): + table_size += buffer_size + time = np.pad(time, (0, buffer_size)) + block = np.pad(block, (0, 
buffer_size)) + weights = np.pad(weights, ((0, buffer_size), (0, 0))) + time[table_idx] = [tree.time(i) for i in nodes_to_add] + block[table_idx] = current_block + num_record += len(nodes_to_add) + weights[indices[internal_nodes], :] += ( + span_weight * running_output[rows_in_running, :] + ) + if tree.interval.right < right: # if current tree does not cross window boundary, move to next tree.next() - self._update_weights_by_edge_diff( - tree, next(edge_diffs), running_weights + self._update_running_with_edge_diff( + tree, next(edge_diffs), running_output, running_state, running_index ) else: # use current tree as initial tree for next window break + # reweight span so that weights are averaged over nonmissing trees + if self.span_normalise: + weights *= total_span / accessible_span + return CoalescenceTimeTable(time, block, weights) def _generate_ecdf_tables(self, ts, window_breaks): @@ -437,11 +548,35 @@ def _generate_ecdf_tables(self, ts, window_breaks): tree = ts.first() edge_diffs = ts.edge_diffs() - running_weights = np.zeros((ts.num_nodes, self.num_weights)) - self._update_weights_by_edge_diff(tree, next(edge_diffs), running_weights) + + # initialize running arrays for first tree + running_index = {i: n for i, n in enumerate(tree.samples())} + running_output = np.zeros( + (self.running_array_size, self.num_weights), + dtype=np.float64, + ) + running_state = np.zeros( + (self.running_array_size, self.num_states), + dtype=np.float64, + ) + for node in tree.samples(): + state = self._initialize(node, self.sample_sets) + for state_index, x in zip(self.state_indices, state): + running_state[running_index[node], state_index] = x + + self._update_running_with_edge_diff( + tree, next(edge_diffs), running_output, running_state, running_index + ) + for left, right in zip(window_breaks[:-1], window_breaks[1:]): yield self._build_ecdf_table_for_window( - left, right, tree, edge_diffs, running_weights + left, + right, + tree, + edge_diffs, + running_output, + 
running_state, + running_index, ) def __init__( @@ -463,18 +598,47 @@ def __init__( self.sample_sets = sample_sets if weight_func is None or weight_func == "coalescence_events": - self.weight_func = self._count_coalescence_events + self._initialize, self._update = self._count_coalescence_events() elif weight_func == "pair_coalescence_events": - self.weight_func = self._count_pair_coalescence_events + self._initialize, self._update = self._count_pair_coalescence_events() elif weight_func == "trio_first_coalescence_events": - self.weight_func = self._count_trio_first_coalescence_events + self._initialize, self._update = self._count_trio_first_coalescence_events() else: - assert callable(weight_func) - self.weight_func = weight_func - _weight_func_eval = self.weight_func(0, ts.first(), self.sample_sets) - assert isinstance(_weight_func_eval, np.ndarray) - assert _weight_func_eval.ndim == 1 - self.num_weights = len(_weight_func_eval) + # user supplies pair of callables ``(initialize, update)`` + assert isinstance(weight_func, tuple) + assert len(weight_func) == 2 + self._initialize, self._update = weight_func + assert callable(self._initialize) + assert callable(self._update) + + # check initialization operation + _state = self._initialize(0, self.sample_sets) + assert isinstance(_state, tuple) + self.num_states = 0 + self.state_indices = [] + for x in _state: + # ``assert is_row_vector(x)`` + assert isinstance(x, np.ndarray) + assert x.ndim == 2 + assert x.shape[0] == 1 + index = list(range(self.num_states, self.num_states + x.size)) + self.state_indices.append(index) + self.num_states += x.size + + # check update operation + _weights, _state = self._update(*_state) + assert isinstance(_state, tuple) + for state_index, x in zip(self.state_indices, _state): + # ``assert is_row_vector(x, len(state_index))`` + assert isinstance(x, np.ndarray) + assert x.ndim == 2 + assert x.shape[0] == 1 + assert x.size == len(state_index) + # ``assert is_row_vector(_weights)`` + assert 
isinstance(_weights, np.ndarray) + assert _weights.ndim == 2 + assert _weights.shape[0] == 1 + self.num_weights = _weights.size if window_breaks is None: window_breaks = np.array([0.0, ts.sequence_length]) @@ -499,6 +663,7 @@ def __init__( self.span_normalise = span_normalise self.buffer_size = ts.num_nodes + self.running_array_size = ts.num_samples * 2 - 1 # assumes no unary nodes self.weight_func_evals = 0 self.tables = [table for table in self._generate_ecdf_tables(ts, window_breaks)] @@ -533,13 +698,44 @@ def ecdf(self, times): values[:, :, k] = table.quantile[indices, :].T return values - # TODO - # - # def quantile(self, times): - # """ - # Return interpolated quantiles of coalescence times, using the same - # approach as numpy.quantile(..., method="linear") - # """ + def quantile(self, quantiles): + """ + Return interpolated quantiles of weighted coalescence times. + """ + + assert isinstance(quantiles, np.ndarray) + assert quantiles.ndim == 1 + assert np.all(np.logical_and(quantiles >= 0, quantiles <= 1)) + + values = np.empty((self.num_weights, quantiles.size, self.num_windows)) + values[:] = np.nan + for k, table in enumerate(self.tables): + # retrieve ECDF for each unique timepoint in table + last_index = np.flatnonzero(table.time[:-1] != table.time[1:]) + time = np.append(table.time[last_index], table.time[-1]) + ecdf = np.append( + table.quantile[last_index, :], table.quantile[[-1]], axis=0 + ) + for i in range(self.num_weights): + if not np.isnan(ecdf[-1, i]): + # interpolation requires strictly increasing arguments, so + # retrieve leftmost x for step-like F(x), including F(0) = 0. 
+ assert ecdf[-1, i] == 1.0 + assert ecdf[0, i] == 0.0 + delta = ecdf[1:, i] - ecdf[:-1, i] + first_index = 1 + np.flatnonzero(delta > 0) + + n_eff = first_index.size + weight = delta[first_index - 1] + cum_weight = np.roll(ecdf[first_index, i], 1) + cum_weight[0] = 0 + midpoint = np.arange(n_eff) * weight + (n_eff - 1) * cum_weight + assert midpoint[0] == 0 + assert midpoint[-1] == n_eff - 1 + values[i, :, k] = np.interp( + quantiles * (n_eff - 1), midpoint, time[first_index] + ) + return values def num_coalesced(self, times): """ @@ -597,11 +793,12 @@ def mean(self, since=0.0): values[:, k] = np.nan else: for i in range(self.num_weights): - if table.cum_weights[-1, i] > 0: - multiplier = table.block_multiplier[table.block[index:]] + multiplier = table.block_multiplier[table.block[index:]] + weights = table.weights[index:, i] * multiplier + if np.any(weights > 0): values[i, k] = np.average( table.time[index:] - since, - weights=table.weights[index:, i] * multiplier, + weights=weights, ) return values From c12f384c6c7820f783a59537b6a31ba3e6a141fb Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Fri, 24 Feb 2023 10:09:25 +0000 Subject: [PATCH 39/84] Don't fail CI on codecov error --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e7e0183cb6..2da829034e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -160,7 +160,7 @@ jobs: with: token: ${{ secrets.CODECOV_TOKEN }} working-directory: python - fail_ci_if_error: true + fail_ci_if_error: false flags: python-tests name: codecov-umbrella verbose: true From 5915a4675efb4917845ea06c320105bd74f18aff Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Mon, 3 Apr 2023 15:48:36 +0100 Subject: [PATCH 40/84] Fix mergify --- .mergify.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.mergify.yml b/.mergify.yml index 2ab3025a55..3c71bac359 100644 --- a/.mergify.yml +++ b/.mergify.yml @@ -82,5 
+82,4 @@ pull_request_rules: queue: name: default method: rebase - rebase_fallback: none update_method: rebase From 740271646721a4d5ce6a804fde8b5e57bf2e2a35 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Tue, 4 Apr 2023 12:27:46 +0100 Subject: [PATCH 41/84] Fix mergify --- .mergify.yml | 29 ----------------------------- 1 file changed, 29 deletions(-) diff --git a/.mergify.yml b/.mergify.yml index 3c71bac359..8b0d980744 100644 --- a/.mergify.yml +++ b/.mergify.yml @@ -43,7 +43,6 @@ pull_request_rules: queue: name: default method: rebase - rebase_fallback: none update_method: rebase - name: Remove label after merge @@ -55,31 +54,3 @@ pull_request_rules: remove: - AUTOMERGE-REQUESTED - - name: Automatic dep update - conditions: - - author~=^dependabot(|-preview)\[bot\]$ - - "-merged" - - base=main - - label=dependancy-upgrade - - status-success=Docs - - status-success=Lint - - status-success=Python (3.7, macos-latest) - - status-success=Python (3.9, macos-latest) - - status-success=Python (3.11, macos-latest) - - status-success=Python (3.7, ubuntu-latest) - - status-success=Python (3.9, ubuntu-latest) - - status-success=Python (3.11, ubuntu-latest) - - status-success=Python (3.7, windows-latest) - - status-success=Python (3.9, windows-latest) - - status-success=Python (3.11, windows-latest) - - "status-success=ci/circleci: build" - - "status-success=ci/circleci: build-32" - - status-success=codecov/patch - - status-success=codecov/project/c-tests - - status-success=codecov/project/python-c-tests - - status-success=codecov/project/python-tests - actions: - queue: - name: default - method: rebase - update_method: rebase From b7cf0dc0b56e6989830aa90b9d086c1a96e5e317 Mon Sep 17 00:00:00 2001 From: Peter Ralph Date: Fri, 14 Apr 2023 15:38:14 -0700 Subject: [PATCH 42/84] remove unused variable (#2739) * remove unused variable closes #2738 * remove other tree_index --- c/tskit/trees.c | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/c/tskit/trees.c 
b/c/tskit/trees.c index 8d202d0163..06cad0e823 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -1264,7 +1264,7 @@ tsk_treeseq_branch_general_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, { int ret = 0; tsk_id_t u, v; - tsk_size_t j, k, tree_index, window_index; + tsk_size_t j, k, window_index; tsk_size_t num_nodes = self->tables->nodes.num_rows; const tsk_id_t num_edges = (tsk_id_t) self->tables->edges.num_rows; const tsk_id_t *restrict I = self->tables->indexes.edge_insertion_order; @@ -1315,7 +1315,6 @@ tsk_treeseq_branch_general_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, tj = 0; tk = 0; t_left = 0; - tree_index = 0; window_index = 0; while (tj < num_edges || t_left < sequence_length) { while (tk < num_edges && edge_right[O[tk]] == t_left) { @@ -1400,7 +1399,6 @@ tsk_treeseq_branch_general_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, } /* Move to the next tree */ t_left = t_right; - tree_index++; } tsk_bug_assert(window_index == num_windows); out: From c72109fe2f5b55dff5c95e3b06fe5f8e6116da88 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Tue, 18 Apr 2023 12:11:08 +0100 Subject: [PATCH 43/84] Fix codecov --- .circleci/config.yml | 27 ++++++++++++++++++++------- .github/workflows/docs.yml | 2 +- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 7066012b29..606b3f7f61 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,4 +1,6 @@ version: 2.1 +orbs: + codecov: codecov/codecov@3.2.4 commands: setup: @@ -48,12 +50,14 @@ commands: ninja -C build-gcc test - run: - name: Run gcov & upload coverage. + name: Run gcov command: | cd build-gcc find ../c/tskit/*.c -type f -printf "%f\n" | xargs -i gcov -pb libtskit.a.p/tskit_{}.gcno ../c/tskit/{} - cd .. - bash <(curl -s https://codecov.io/bash) -X gcov -X coveragepy -F c-tests + + - codecov/upload: + flags: c-tests + token: CODECOV_TOKEN - run: name: Valgrind for C tests. 
@@ -116,14 +120,17 @@ commands: python -m pytest -n2 - run: - name: Upload LWT coverage + name: Generate coverage command: | # Make sure the C coverage reports aren't lying around rm -fR build-gcc ls -R cd python/lwt_interface gcov -pb -o ./build/temp.linux*/*.gcno example_c_module.c - bash <(curl -s https://codecov.io/bash) -X gcov -F lwt-tests + + - codecov/upload: + flags: lwt-tests + token: CODECOV_TOKEN - run: name: Run Python tests @@ -132,13 +139,17 @@ commands: python -m pytest --cov=tskit --cov-report=xml --cov-branch -n2 tests/test_lowlevel.py tests/test_tables.py tests/test_file_format.py - run: - name: Upload Python coverage + name: Generate Python coverage command: | # Make sure the C coverage reports aren't lying around rm -fR build-gcc + rm -f python/lwt_interface/*.gcov cd python gcov -pb -o ./build/temp.linux*/*.gcno _tskitmodule.c - bash <(curl -s https://codecov.io/bash) -X gcov -F python-c-tests + + - codecov/upload: + flags: python-c-tests + token: CODECOV_TOKEN - run: name: Build Python package @@ -191,6 +202,8 @@ jobs: key: tskit-32-{{ .Branch }}-v7 paths: - "/home/circleci/.local" + # We need to install curl for the codecov upload. + - run: sudo apt-get install -y curl - compile_and_test workflows: diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 27d11ef5fe..db0c1c6d09 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -40,7 +40,7 @@ jobs: id: venv-cache with: path: venv - key: docs-venv-v2-${{ hashFiles(env.REQUIREMENTS) }} + key: docs-venv-v3-${{ hashFiles(env.REQUIREMENTS) }} - name: Create venv and install deps (one by one to avoid conflict errors) if: steps.venv-cache.outputs.cache-hit != 'true' From edb3df259937b7b290c6f71a96f3aed82b8603e0 Mon Sep 17 00:00:00 2001 From: "Kevin R. Thornton" Date: Thu, 15 Dec 2022 10:35:41 -0800 Subject: [PATCH 44/84] Implement efficient initialization of null trees. * Add seeking to a tree by index. 
Closes #2659 --- c/tests/test_trees.c | 10 +++ c/tskit/trees.c | 152 +++++++++++++++++++++++++++++++-- c/tskit/trees.h | 22 +++-- python/tests/test_highlevel.py | 48 ++++++++++- 4 files changed, 217 insertions(+), 15 deletions(-) diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c index f0ced8585f..94e33ee487 100644 --- a/c/tests/test_trees.c +++ b/c/tests/test_trees.c @@ -6140,10 +6140,16 @@ test_seek_multi_tree(void) ret = tsk_tree_seek(&t, breakpoints[j], 0); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_EQUAL_FATAL(t.index, j); + ret = tsk_tree_seek_index(&t, j, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t.index, j); for (k = 0; k < num_trees; k++) { ret = tsk_tree_seek(&t, breakpoints[k], 0); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_EQUAL_FATAL(t.index, k); + ret = tsk_tree_seek_index(&t, k, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t.index, k); } } @@ -6205,6 +6211,10 @@ test_seek_errors(void) CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SEEK_OUT_OF_BOUNDS); ret = tsk_tree_seek(&t, 11, 0); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SEEK_OUT_OF_BOUNDS); + ret = tsk_tree_seek_index(&t, (tsk_id_t) ts.num_trees, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SEEK_OUT_OF_BOUNDS); + ret = tsk_tree_seek_index(&t, -1, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SEEK_OUT_OF_BOUNDS); tsk_tree_free(&t); tsk_treeseq_free(&ts); diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 06cad0e823..027d8fff91 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -4549,19 +4549,138 @@ tsk_tree_position_in_interval(const tsk_tree_t *self, double x) return self->interval.left <= x && x < self->interval.right; } -int TSK_WARN_UNUSED -tsk_tree_seek(tsk_tree_t *self, double x, tsk_flags_t TSK_UNUSED(options)) +/* NOTE: + * + * Notes from Kevin Thornton: + * + * This method inserts the edges for an arbitrary tree + * in linear time and requires no additional memory. + * + * During design, the following alternatives were tested + * (in a combination of rust + C): + * 1. 
Indexing edge insertion/removal locations by tree. + * The indexing can be done in O(n) time, giving O(1) + * access to the first edge in a tree. We can then add + * edges to the tree in O(e) time, where e is the number + * of edges. This apparoach requires O(n) additional memory + * and is only marginally faster than the implementation below. + * 2. Building an interval tree mapping edge id -> span. + * This approach adds a lot of complexity and wasn't any faster + * than the indexing described above. + */ +static int +tsk_tree_seek_from_null(tsk_tree_t *self, double x, tsk_flags_t TSK_UNUSED(options)) { int ret = 0; + tsk_size_t edge; + tsk_id_t p, c, e, j, k, tree_index; const double L = tsk_treeseq_get_sequence_length(self->tree_sequence); - const double t_l = self->interval.left; - const double t_r = self->interval.right; - double distance_left, distance_right; + const tsk_treeseq_t *treeseq = self->tree_sequence; + const tsk_table_collection_t *tables = treeseq->tables; + const tsk_id_t *restrict edge_parent = tables->edges.parent; + const tsk_id_t *restrict edge_child = tables->edges.child; + const tsk_size_t num_edges = tables->edges.num_rows; + const tsk_size_t num_trees = self->tree_sequence->num_trees; + const double *restrict edge_left = tables->edges.left; + const double *restrict edge_right = tables->edges.right; + const double *restrict breakpoints = treeseq->breakpoints; + const tsk_id_t *restrict insertion = tables->indexes.edge_insertion_order; + const tsk_id_t *restrict removal = tables->indexes.edge_removal_order; + + // NOTE: it may be better to get the + // index first and then ask if we are + // searching in the first or last 1/2 + // of trees. 
+ j = -1; + if (x <= L / 2.0) { + for (edge = 0; edge < num_edges; edge++) { + e = insertion[edge]; + if (edge_left[e] > x) { + j = (tsk_id_t) edge; + break; + } + if (x >= edge_left[e] && x < edge_right[e]) { + p = edge_parent[e]; + c = edge_child[e]; + tsk_tree_insert_edge(self, p, c, e); + } + } + } else { + for (edge = 0; edge < num_edges; edge++) { + e = removal[num_edges - edge - 1]; + if (edge_right[e] < x) { + j = (tsk_id_t)(num_edges - edge - 1); + while (j < (tsk_id_t) num_edges && edge_left[insertion[j]] <= x) { + j++; + } + break; + } + if (x >= edge_left[e] && x < edge_right[e]) { + p = edge_parent[e]; + c = edge_child[e]; + tsk_tree_insert_edge(self, p, c, e); + } + } + } - if (x < 0 || x >= L) { + if (j == -1) { + j = 0; + while (j < (tsk_id_t) num_edges && edge_left[insertion[j]] <= x) { + j++; + } + } + k = 0; + while (k < (tsk_id_t) num_edges && edge_right[removal[k]] <= x) { + k++; + } + + /* NOTE: tsk_search_sorted finds the first the first + * insertion locatiom >= the query point, which + * finds a RIGHT value for queries not at the left edge. 
+ */ + tree_index = (tsk_id_t) tsk_search_sorted(breakpoints, num_trees + 1, x); + if (breakpoints[tree_index] > x) { + tree_index -= 1; + } + self->index = tree_index; + self->interval.left = breakpoints[tree_index]; + self->interval.right = breakpoints[tree_index + 1]; + self->left_index = j; + self->right_index = k; + self->direction = TSK_DIR_FORWARD; + self->num_nodes = tables->nodes.num_rows; + if (tables->sites.num_rows > 0) { + self->sites = treeseq->tree_sites[self->index]; + self->sites_length = treeseq->tree_sites_length[self->index]; + } + + return ret; +} + +int TSK_WARN_UNUSED +tsk_tree_seek_index(tsk_tree_t *self, tsk_id_t tree, tsk_flags_t options) +{ + int ret = 0; + double x; + + if (tree < 0 || tree >= (tsk_id_t) self->tree_sequence->num_trees) { ret = TSK_ERR_SEEK_OUT_OF_BOUNDS; goto out; } + x = self->tree_sequence->breakpoints[tree]; + ret = tsk_tree_seek(self, x, options); +out: + return ret; +} + +static int TSK_WARN_UNUSED +tsk_tree_seek_linear(tsk_tree_t *self, double x, tsk_flags_t TSK_UNUSED(options)) +{ + const double L = tsk_treeseq_get_sequence_length(self->tree_sequence); + const double t_l = self->interval.left; + const double t_r = self->interval.right; + int ret = 0; + double distance_left, distance_right; if (x < t_l) { /* |-----|-----|========|---------| */ @@ -4594,6 +4713,27 @@ tsk_tree_seek(tsk_tree_t *self, double x, tsk_flags_t TSK_UNUSED(options)) return ret; } +int TSK_WARN_UNUSED +tsk_tree_seek(tsk_tree_t *self, double x, tsk_flags_t options) +{ + int ret = 0; + const double L = tsk_treeseq_get_sequence_length(self->tree_sequence); + + if (x < 0 || x >= L) { + ret = TSK_ERR_SEEK_OUT_OF_BOUNDS; + goto out; + } + + if (self->index == -1) { + ret = tsk_tree_seek_from_null(self, x, options); + } else { + ret = tsk_tree_seek_linear(self, x, options); + } + +out: + return ret; +} + int TSK_WARN_UNUSED tsk_tree_clear(tsk_tree_t *self) { diff --git a/c/tskit/trees.h b/c/tskit/trees.h index 4a84bf3446..efe9980077 100644 --- 
a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1192,12 +1192,6 @@ we will have ``position < tree.interval.right``. Seeking to a position currently covered by the tree is a constant time operation. - -.. warning:: - The current implementation of ``seek`` does **not** provide efficient - random access to arbitrary positions along the genome. However, - sequentially seeking in either direction is as efficient as calling - :c:func:`tsk_tree_next` or :c:func:`tsk_tree_prev` directly. @endrst @param self A pointer to an initialised tsk_tree_t object. @@ -1208,6 +1202,22 @@ a constant time operation. */ int tsk_tree_seek(tsk_tree_t *self, double position, tsk_flags_t options); +/** +@brief Seek to a specific tree in a tree sequence. + +@rst +Set the state of this tree to reflect the tree in parent +tree sequence whose index is ``0 <= tree < num_trees``. +@endrst + +@param self A pointer to an initialised tsk_tree_t object. +@param tree The target tree index. +@param options Seek options. Currently unused. Set to 0 for compatibility + with future versions of tskit. +@return Return 0 on success or a negative value on failure. 
+*/ +int tsk_tree_seek_index(tsk_tree_t *self, tsk_id_t tree, tsk_flags_t options); + /** @} */ /** diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index c926121aea..ca67e24ce0 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -4598,10 +4598,13 @@ def test_index_from_different_directions(self, index): t2.prev() assert_same_tree_different_order(t1, t2) - def test_seek_0_from_null(self): + @pytest.mark.parametrize("position", [0, 1, 2, 3]) + def test_seek_from_null(self, position): t1, t2 = self.setup() - t1.first() - t2.seek(0) + t1.clear() + t1.seek(position) + t2.first() + t2.seek(position) assert_trees_identical(t1, t2) @pytest.mark.parametrize("index", range(3)) @@ -4654,6 +4657,14 @@ def test_seek_3_from_null(self): t2.seek(3) assert_trees_identical(t1, t2) + def test_seek_3_from_null_prev(self): + t1, t2 = self.setup() + t1.last() + t1.prev() + t2.seek(3) + t2.prev() + assert_trees_identical(t1, t2) + def test_seek_3_from_0(self): t1, t2 = self.setup() t1.last() @@ -4669,6 +4680,37 @@ def test_seek_0_from_3(self): t2.seek(0) assert_trees_identical(t1, t2) + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_seek_mid_null_and_middle(self, ts): + breakpoints = ts.breakpoints(as_array=True) + mid = breakpoints[:-1] + np.diff(breakpoints) / 2 + for index, x in enumerate(mid[:-1]): + t1 = tskit.Tree(ts) + t1.seek(x) + # Also seek to this point manually to make sure we're not + # reusing the seek from null under the hood. 
+ t2 = tskit.Tree(ts) + if index <= ts.num_trees / 2: + while t2.index != index: + t2.next() + else: + while t2.index != index: + t2.prev() + assert t1.index == t2.index + assert np.all(t1.parent_array == t2.parent_array) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_seek_last_then_prev(self, ts): + t1 = tskit.Tree(ts) + t1.seek(ts.sequence_length - 0.00001) + assert t1.index == ts.num_trees - 1 + t2 = tskit.Tree(ts) + t2.prev() + assert_trees_identical(t1, t2) + t1.prev() + t2.prev() + assert_trees_identical(t1, t2) + class TestSeek: @pytest.mark.parametrize("ts", get_example_tree_sequences()) From 7ddcc95afc9b8430749e70b371f0da42d8928bee Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Tue, 11 Apr 2023 22:45:20 +0100 Subject: [PATCH 45/84] Use C implementation for seek_index Closes #2696 --- c/CHANGELOG.rst | 7 +++++++ c/tskit/trees.c | 2 +- python/CHANGELOG.rst | 5 +++++ python/_tskitmodule.c | 27 +++++++++++++++++++++++++++ python/tests/test_lowlevel.py | 10 ++++++++++ python/tskit/trees.py | 8 +------- 6 files changed, 51 insertions(+), 8 deletions(-) diff --git a/c/CHANGELOG.rst b/c/CHANGELOG.rst index c1643e1ff1..c8dcb7143f 100644 --- a/c/CHANGELOG.rst +++ b/c/CHANGELOG.rst @@ -2,6 +2,11 @@ [1.1.2] - 2023-XX-XX -------------------- +**Performance improvements** + +- tsk_tree_seek is now much faster at seeking to arbitrary points along + the sequence from the null tree (:user:`molpopgen`, :pr:`2661`). + **Features** - The struct ``tsk_treeseq_t`` now has the variables ``min_time`` and ``max_time``, @@ -24,6 +29,8 @@ - Add `x_table_keep_rows` methods to provide efficient in-place table subsetting (:user:`jeromekelleher`, :pr:`2700`). 
+- Add `tsk_tree_seek_index` function + -------------------- [1.1.1] - 2022-07-29 -------------------- diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 027d8fff91..4604579e0b 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -4640,7 +4640,7 @@ tsk_tree_seek_from_null(tsk_tree_t *self, double x, tsk_flags_t TSK_UNUSED(optio */ tree_index = (tsk_id_t) tsk_search_sorted(breakpoints, num_trees + 1, x); if (breakpoints[tree_index] > x) { - tree_index -= 1; + tree_index--; } self->index = tree_index; self->interval.left = breakpoints[tree_index]; diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 3a81b2aa1c..4450bf06c4 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -2,6 +2,11 @@ [0.5.5] - 2023-01-XX -------------------- +**Performance improvements** + +- Methods like ts.at() which seek to a specified position on the sequence from + a new Tree instance are now much faster (:user:`molpopgen`, :pr:`2661`). + **Features** - Add ``__repr__`` for variants to return a string representation of the raw data diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 2b379ff2b1..30c3e7743b 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -10658,6 +10658,29 @@ Tree_seek(Tree *self, PyObject *args) return ret; } +static PyObject * +Tree_seek_index(Tree *self, PyObject *args) +{ + PyObject *ret = NULL; + tsk_id_t index = 0; + int err; + + if (Tree_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "O&", tsk_id_converter, &index)) { + goto out; + } + err = tsk_tree_seek_index(self->tree, index, 0); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + static PyObject * Tree_clear(Tree *self) { @@ -11796,6 +11819,10 @@ static PyMethodDef Tree_methods[] = { .ml_meth = (PyCFunction) Tree_seek, .ml_flags = METH_VARARGS, .ml_doc = "Seeks to the tree at the specified position" }, + { .ml_name = "seek_index", + .ml_meth = (PyCFunction) Tree_seek_index, 
+ .ml_flags = METH_VARARGS, + .ml_doc = "Seeks to the tree at the specified index" }, { .ml_name = "clear", .ml_meth = (PyCFunction) Tree_clear, .ml_flags = METH_NOARGS, diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 8bbc201b85..7ebe6467eb 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -2968,6 +2968,16 @@ def test_seek_errors(self): with pytest.raises(_tskit.LibraryError): tree.seek(bad_pos) + def test_seek_index_errors(self): + ts = self.get_example_tree_sequence() + tree = _tskit.Tree(ts) + for bad_type in ["", "x", {}]: + with pytest.raises(TypeError): + tree.seek_index(bad_type) + for bad_index in [-1, 10**6]: + with pytest.raises(_tskit.LibraryError): + tree.seek_index(bad_index) + def test_root_threshold(self): for ts in self.get_example_tree_sequences(): tree = _tskit.Tree(ts) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 39a1e4c41a..da5a7f9d07 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -820,7 +820,6 @@ def seek_index(self, index): .. include:: substitutions/linear_traversal_warning.rst - :param int index: The tree index to seek to. :raises IndexError: If an index outside the acceptable range is provided. """ @@ -829,12 +828,7 @@ def seek_index(self, index): index += num_trees if index < 0 or index >= num_trees: raise IndexError("Index out of bounds") - # This should be implemented in C efficiently using the indexes. - # No point in complicating the current implementation by trying - # to seek from the correct direction. - self.first() - while self.index != index: - self.next() + self._ll_tree.seek_index(index) def seek(self, position): """ From be7d84f7a014c679c9e303aa7a8145b440f73bd3 Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Mon, 17 Apr 2023 13:55:04 +0100 Subject: [PATCH 46/84] Clarify Tree.samples() behaviour w/ isolated nodes Just to make it clear that this still works when a node is not "in" the tree topology. 
--- python/tskit/trees.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index da5a7f9d07..cb75f711fb 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -2181,9 +2181,12 @@ def _sample_generator(self, u): def samples(self, u=None): """ Returns an iterator over the numerical IDs of all the sample nodes in - this tree that are underneath node ``u``. If ``u`` is a sample, it is - included in the returned iterator. If u is not specified, return all - sample node IDs in the tree. + this tree that are underneath the node with ID ``u``. If ``u`` is a sample, + it is included in the returned iterator. If ``u`` is not a sample, it is + possible for the returned iterator to be empty, for example if ``u`` is an + :meth:`isolated` node that is not part of the current + topology. If u is not specified, return all sample node IDs in the tree + (equivalent to all the sample node IDs in the tree sequence). If the :meth:`TreeSequence.trees` method is called with ``sample_lists=True``, this method uses an efficient algorithm to find From 332d5b7461ac895eedb9fd95c5e4bc4d28a817cb Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Thu, 11 May 2023 14:36:11 +0100 Subject: [PATCH 47/84] Fix type to Py_BuildValue --- python/CHANGELOG.rst | 5 +++++ python/_tskitmodule.c | 3 ++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 4450bf06c4..e0ff7117f1 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -15,6 +15,11 @@ - Add ``keep_rows`` method to table classes to support efficient in-place table subsetting (:user:`jeromekelleher`, :pr:`2700`) +**Bugfixes** + +- Fix `UnicodeDecodeError` when calling `Variant.alleles` on the `emscripten` platform. 
+ (:user:`benjeffery`, :pr:`2754`, :issue:`2737`) + -------------------- [0.5.4] - 2023-01-13 -------------------- diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 30c3e7743b..dea3c03fd9 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -578,7 +578,8 @@ make_alleles(tsk_variant_t *variant) goto out; } for (j = 0; j < variant->num_alleles; j++) { - item = Py_BuildValue("s#", variant->alleles[j], variant->allele_lengths[j]); + item = Py_BuildValue( + "s#", variant->alleles[j], (Py_ssize_t) variant->allele_lengths[j]); if (item == NULL) { Py_DECREF(t); goto out; From 4bef5c17888e0d32839cb4c656411805e88f6efc Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Wed, 17 May 2023 16:58:51 +0100 Subject: [PATCH 48/84] Release prep --- c/CHANGELOG.rst | 2 +- c/VERSION.txt | 2 +- c/tskit/core.h | 2 +- python/CHANGELOG.rst | 2 +- python/tests/test_lowlevel.py | 2 +- python/tskit/_version.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/c/CHANGELOG.rst b/c/CHANGELOG.rst index c8dcb7143f..11606a42f2 100644 --- a/c/CHANGELOG.rst +++ b/c/CHANGELOG.rst @@ -1,5 +1,5 @@ -------------------- -[1.1.2] - 2023-XX-XX +[1.1.2] - 2023-05-17 -------------------- **Performance improvements** diff --git a/c/VERSION.txt b/c/VERSION.txt index 8cfbc905b3..8428158dc5 100644 --- a/c/VERSION.txt +++ b/c/VERSION.txt @@ -1 +1 @@ -1.1.1 \ No newline at end of file +1.1.2 \ No newline at end of file diff --git a/c/tskit/core.h b/c/tskit/core.h index 20ca5881da..b8b9f354ba 100644 --- a/c/tskit/core.h +++ b/c/tskit/core.h @@ -152,7 +152,7 @@ to the API or ABI are introduced, i.e., the addition of a new function. The library patch version. Incremented when any changes not relevant to the to the API or ABI are introduced, i.e., internal refactors of bugfixes. 
*/ -#define TSK_VERSION_PATCH 1 +#define TSK_VERSION_PATCH 2 /** @} */ /* diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index e0ff7117f1..c98b5d56f9 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -1,5 +1,5 @@ -------------------- -[0.5.5] - 2023-01-XX +[0.5.5] - 2023-05-17 -------------------- **Performance improvements** diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 7ebe6467eb..70ef08143c 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -3873,7 +3873,7 @@ def test_kastore_version(self): def test_tskit_version(self): version = _tskit.get_tskit_version() - assert version == (1, 1, 1) + assert version == (1, 1, 2) def test_tskit_version_file(self): maj, min_, patch = _tskit.get_tskit_version() diff --git a/python/tskit/_version.py b/python/tskit/_version.py index 71f3c3e882..d36b46f038 100644 --- a/python/tskit/_version.py +++ b/python/tskit/_version.py @@ -1,4 +1,4 @@ # Definitive location for the version number. 
# During development, should be x.y.z.devN # For beta should be x.y.zbN -tskit_version = "0.5.5.dev0" +tskit_version = "0.5.5" From 70bc1222df8dcf250947c399d1cb8348babeb6b3 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Tue, 13 Jun 2023 13:08:24 +0100 Subject: [PATCH 49/84] Remove package test from 32bit due to tricky dependancies --- .circleci/config.yml | 40 ++++++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 606b3f7f61..d4b68b468f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -14,7 +14,6 @@ commands: sudo pip install meson pip install numpy==1.18.5 pip install --user -r python/requirements/CI-complete/requirements.txt - ARGO_NET_GIT_FETCH_WITH_CLI=1 pip install twine --user # Remove tskit installed by msprime pip uninstall tskit -y echo 'export PATH=/home/circleci/.local/bin:$PATH' >> $BASH_ENV @@ -151,23 +150,6 @@ commands: flags: python-c-tests token: CODECOV_TOKEN - - run: - name: Build Python package - command: | - cd python - rm -fR build - python setup.py sdist - python setup.py check - python -m twine check dist/*.tar.gz - python -m venv venv - source venv/bin/activate - pip install --upgrade setuptools pip wheel - python setup.py build_ext - python setup.py egg_info - python setup.py bdist_wheel - pip install dist/*.tar.gz - tskit --help - jobs: build: docker: @@ -187,6 +169,28 @@ jobs: paths: - "/home/circleci/.local" - compile_and_test + - run: + name: Install dependencies for wheel test + command: | + ARGO_NET_GIT_FETCH_WITH_CLI=1 pip install twine --user + # Remove tskit installed by msprime + pip uninstall tskit -y + - run: + name: Build Python package + command: | + cd python + rm -fR build + python setup.py sdist + python setup.py check + python -m twine check dist/*.tar.gz + python -m venv venv + source venv/bin/activate + pip install --upgrade setuptools pip wheel + python setup.py build_ext + python setup.py egg_info + 
python setup.py bdist_wheel + pip install dist/*.tar.gz + tskit --help build-32: docker: From d635c19f584acde7cdeaf62d4ce408f502647ef6 Mon Sep 17 00:00:00 2001 From: duncanMR Date: Tue, 13 Jun 2023 12:21:38 +0100 Subject: [PATCH 50/84] Add method to impute unknown mutation times --- python/CHANGELOG.rst | 8 ++++++++ python/tests/test_highlevel.py | 29 +++++++++++++++++++++++++++++ python/tskit/trees.py | 31 +++++++++++++++++++++++++++++++ 3 files changed, 68 insertions(+) diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index c98b5d56f9..3a851c3e87 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -1,3 +1,11 @@ +-------------------- +[0.5.X] - 2023-XX-XX +-------------------- + +**Features** + +- Add ``TreeSequence.impute_unknown_mutations_time`` method to return an array of mutation times based on the times of associated nodes (:user:`duncanMR`, :pr:`2760`, :issue:`2758`) + -------------------- [0.5.5] - 2023-05-17 -------------------- diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index ca67e24ce0..ce225f1dd7 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -2596,6 +2596,35 @@ def test_arrays_equal_to_tables(self, ts_fixture): ts.indexes_edge_removal_order, tables.indexes.edge_removal_order ) + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_impute_unknown_mutations_time(self, ts): + # Tests for method='min' + imputed_time = ts.impute_unknown_mutations_time(method="min") + mutations = ts.tables.mutations + nodes_time = ts.nodes_time + table_time = np.zeros(len(mutations)) + + for mut_idx, mut in enumerate(mutations): + if tskit.is_unknown_time(mut.time): + node_time = nodes_time[mut.node] + table_time[mut_idx] = node_time + else: + table_time[mut_idx] = mut.time + + assert np.allclose(imputed_time, table_time, rtol=1e-10, atol=1e-10) + + # Check we have valid times + tables = ts.dump_tables() + tables.mutations.time = imputed_time + tables.sort() + 
tables.tree_sequence() + + # Test for unallowed methods + with pytest.raises( + ValueError, match="Mutations time imputation method must be chosen" + ): + ts.impute_unknown_mutations_time(method="foobar") + class TestSimplify: # This class was factored out of the old TestHighlevel class 2022-12-13, diff --git a/python/tskit/trees.py b/python/tskit/trees.py index cb75f711fb..ffa7ca20ba 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8270,6 +8270,7 @@ def Tajimas_D(self, sample_sets=None, windows=None, mode="site"): :return: A ndarray with shape equal to (num windows, num statistics). If there is one sample set and windows=None, a numpy scalar is returned. """ + # TODO this should be done in C as we'll want to support this method there. def tjd_func(sample_set_sizes, flattened, **kwargs): n = sample_set_sizes @@ -8882,6 +8883,36 @@ def coalescence_time_distribution( span_normalise=span_normalise, ) + def impute_unknown_mutations_time( + self, + method=None, + ): + """ + Returns an array of mutation times, where any unknown times are + imputed from the times of associated nodes. Not to be confused with + :meth:`TableCollection.compute_mutation_times`, which modifies the + ``time`` column of the mutations table in place. + + :param str method: The method used to impute the unknown mutation times. + Currently only "min" is supported, which uses the time of the node + below the mutation as the mutation time. The "min" method can also + be specified by ``method=None`` (Default: ``None``). + :return: An array of length equal to the number of mutations in the + tree sequence. 
+ """ + allowed_methods = ["min"] + if method is None: + method = "min" + if method not in allowed_methods: + raise ValueError( + f"Mutations time imputation method must be chosen from {allowed_methods}" + ) + if method == "min": + mutations_time = self.mutations_time.copy() + unknown = tskit.is_unknown_time(mutations_time) + mutations_time[unknown] = self.nodes_time[self.mutations_node[unknown]] + return mutations_time + ############################################ # # Deprecated APIs. These are either already unsupported, or will be unsupported in a From b579afce38e20a6a3eb9d580c34c3705c37d8ab2 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Tue, 20 Jun 2023 09:52:23 +0100 Subject: [PATCH 51/84] Bump dependancy cache --- .github/workflows/docs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index db0c1c6d09..668470e1f6 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -40,7 +40,7 @@ jobs: id: venv-cache with: path: venv - key: docs-venv-v3-${{ hashFiles(env.REQUIREMENTS) }} + key: docs-venv-v4-${{ hashFiles(env.REQUIREMENTS) }} - name: Create venv and install deps (one by one to avoid conflict errors) if: steps.venv-cache.outputs.cache-hit != 'true' From 93cd81f6b7886d55119f4b9e32e7ad2cdffa5da8 Mon Sep 17 00:00:00 2001 From: PalashLalwani Date: Mon, 19 Jun 2023 17:09:42 +0530 Subject: [PATCH 52/84] added nodes time refactor code for issue no. 
#2766 --- .gitignore | 2 ++ python/tskit/trees.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 32c8ed68d4..fcf4695e11 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,6 @@ build-gcc python/benchmark/*.trees python/benchmark/*.json python/benchmark/*.html +.venv +.env diff --git a/python/tskit/trees.py b/python/tskit/trees.py index ffa7ca20ba..9ccae3488d 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -4561,14 +4561,14 @@ def max_root_time(self): raise ValueError( "max_root_time is not defined in a tree sequence with 0 samples" ) - ret = max(self.node(u).time for u in self.samples()) + ret = max(self.nodes_time[u] for u in self.samples()) if self.num_edges > 0: # Edges are guaranteed to be listed in parent-time order, so we can get the # last one to get the oldest root edge = self.edge(self.num_edges - 1) # However, we can have situations where there is a sample older than a # 'proper' root - ret = max(ret, self.node(edge.parent).time) + ret = max(ret, self.nodes_time[edge.parent]) return ret def migrations(self): From 3d4fc5194937e19fb2817dd33a8379e366c9d07b Mon Sep 17 00:00:00 2001 From: astheeggeggs Date: Wed, 31 Aug 2022 16:57:54 +0100 Subject: [PATCH 53/84] added forwards backwards testing Added forwards backwards testing and now include missingness appropriately added missingness to diploid LS added some fixes for flake errors added missingness to diploid viterbi changed test_genotype_matching_fb.py remove stray print removed caps for bool EQUAL_BOTH_HOM etc Removed caps for EQUAL_BOTH_HOM etc in Viterbi removed unused imported function --- python/tests/test_genotype_matching_fb.py | 128 ++++++++++++------ .../tests/test_genotype_matching_viterbi.py | 74 ++++++---- 2 files changed, 133 insertions(+), 69 deletions(-) diff --git a/python/tests/test_genotype_matching_fb.py b/python/tests/test_genotype_matching_fb.py index 984b3ce13a..248382e913 100644 --- 
a/python/tests/test_genotype_matching_fb.py +++ b/python/tests/test_genotype_matching_fb.py @@ -1,4 +1,3 @@ -# Simulation import copy import itertools @@ -14,6 +13,8 @@ REF_HOM_OBS_HET = 1 REF_HET_OBS_HOM = 2 +MISSING = -1 + def mirror_coordinates(ts): """ @@ -411,6 +412,7 @@ def update_probabilities(self, site, genotype_state): ] query_is_het = genotype_state == 1 + query_is_missing = genotype_state == MISSING for st1 in T: u1 = st1.tree_node @@ -444,6 +446,7 @@ def update_probabilities(self, site, genotype_state): match, template_is_het, query_is_het, + query_is_missing, ) # This will ensure that allelic_state[:n] is filled @@ -561,7 +564,14 @@ def compute_normalisation_factor_dict(self): raise NotImplementedError() def compute_next_probability_dict( - self, site_id, p_last, inner_summation, is_match, template_is_het, query_is_het + self, + site_id, + p_last, + inner_summation, + is_match, + template_is_het, + query_is_het, + query_is_missing, ): raise NotImplementedError() @@ -670,41 +680,45 @@ def compute_next_probability_dict( is_match, template_is_het, query_is_het, + query_is_missing, ): rho = self.rho[site_id] mu = self.mu[site_id] n = self.ts.num_samples - template_is_hom = np.logical_not(template_is_het) - query_is_hom = np.logical_not(query_is_het) - - EQUAL_BOTH_HOM = np.logical_and( - np.logical_and(is_match, template_is_hom), query_is_hom - ) - UNEQUAL_BOTH_HOM = np.logical_and( - np.logical_and(np.logical_not(is_match), template_is_hom), query_is_hom - ) - BOTH_HET = np.logical_and(template_is_het, query_is_het) - REF_HOM_OBS_HET = np.logical_and(template_is_hom, query_is_het) - REF_HET_OBS_HOM = np.logical_and(template_is_het, query_is_hom) - p_t = ( (rho / n) ** 2 + ((1 - rho) * (rho / n)) * inner_normalisation_factor + (1 - rho) ** 2 * p_last ) - p_e = ( - EQUAL_BOTH_HOM * (1 - mu) ** 2 - + UNEQUAL_BOTH_HOM * (mu**2) - + REF_HOM_OBS_HET * (2 * mu * (1 - mu)) - + REF_HET_OBS_HOM * (mu * (1 - mu)) - + BOTH_HET * ((1 - mu) ** 2 + mu**2) - ) + + if 
query_is_missing: + p_e = 1 + else: + query_is_hom = np.logical_not(query_is_het) + template_is_hom = np.logical_not(template_is_het) + + equal_both_hom = np.logical_and( + np.logical_and(is_match, template_is_hom), query_is_hom + ) + unequal_both_hom = np.logical_and( + np.logical_and(np.logical_not(is_match), template_is_hom), query_is_hom + ) + both_het = np.logical_and(template_is_het, query_is_het) + ref_hom_obs_het = np.logical_and(template_is_hom, query_is_het) + ref_het_obs_hom = np.logical_and(template_is_het, query_is_hom) + + p_e = ( + equal_both_hom * (1 - mu) ** 2 + + unequal_both_hom * (mu**2) + + ref_hom_obs_het * (2 * mu * (1 - mu)) + + ref_het_obs_hom * (mu * (1 - mu)) + + both_het * ((1 - mu) ** 2 + mu**2) + ) return p_t * p_e -# DEV: Sort this class BackwardAlgorithm(LsHmmAlgorithm): """Runs the Li and Stephens forward algorithm.""" @@ -737,29 +751,35 @@ def compute_next_probability_dict( is_match, template_is_het, query_is_het, + query_is_missing, ): mu = self.mu[site_id] template_is_hom = np.logical_not(template_is_het) - query_is_hom = np.logical_not(query_is_het) - EQUAL_BOTH_HOM = np.logical_and( - np.logical_and(is_match, template_is_hom), query_is_hom - ) - UNEQUAL_BOTH_HOM = np.logical_and( - np.logical_and(np.logical_not(is_match), template_is_hom), query_is_hom - ) - BOTH_HET = np.logical_and(template_is_het, query_is_het) - REF_HOM_OBS_HET = np.logical_and(template_is_hom, query_is_het) - REF_HET_OBS_HOM = np.logical_and(template_is_het, query_is_hom) - - p_e = ( - EQUAL_BOTH_HOM * (1 - mu) ** 2 - + UNEQUAL_BOTH_HOM * (mu**2) - + REF_HOM_OBS_HET * (2 * mu * (1 - mu)) - + REF_HET_OBS_HOM * (mu * (1 - mu)) - + BOTH_HET * ((1 - mu) ** 2 + mu**2) - ) + if query_is_missing: + p_e = 1 + else: + query_is_hom = np.logical_not(query_is_het) + + equal_both_hom = np.logical_and( + np.logical_and(is_match, template_is_hom), query_is_hom + ) + unequal_both_hom = np.logical_and( + np.logical_and(np.logical_not(is_match), template_is_hom), 
query_is_hom + ) + both_het = np.logical_and(template_is_het, query_is_het) + ref_hom_obs_het = np.logical_and(template_is_hom, query_is_het) + ref_het_obs_hom = np.logical_and(template_is_het, query_is_hom) + + p_e = ( + equal_both_hom * (1 - mu) ** 2 + + unequal_both_hom * (mu**2) + + ref_hom_obs_het * (2 * mu * (1 - mu)) + + ref_het_obs_hom * (mu * (1 - mu)) + + both_het * ((1 - mu) ** 2 + mu**2) + ) + return p_next * p_e @@ -797,6 +817,21 @@ def example_genotypes(self, ts): s = H[:, 0].reshape(1, H.shape[0]) + H[:, 1].reshape(1, H.shape[0]) H = H[:, 2:] + genotypes = [ + s, + H[:, -1].reshape(1, H.shape[0]) + H[:, -2].reshape(1, H.shape[0]), + ] + + s_tmp = s.copy() + s_tmp[0, -1] = MISSING + genotypes.append(s_tmp) + s_tmp = s.copy() + s_tmp[0, ts.num_sites // 2] = MISSING + genotypes.append(s_tmp) + s_tmp = s.copy() + s_tmp[0, :] = MISSING + genotypes.append(s_tmp) + m = ts.get_num_sites() n = H.shape[1] @@ -804,11 +839,11 @@ def example_genotypes(self, ts): for i in range(m): G[i, :, :] = np.add.outer(H[i, :], H[i, :]) - return H, G, s + return H, G, genotypes def example_parameters_genotypes(self, ts, seed=42): np.random.seed(seed) - H, G, s = self.example_genotypes(ts) + H, G, genotypes = self.example_genotypes(ts) n = H.shape[1] m = ts.get_num_sites() @@ -819,13 +854,16 @@ def example_parameters_genotypes(self, ts, seed=42): e = self.genotype_emission(mu, m) - yield n, m, G, s, e, r, mu + for s in genotypes: + yield n, m, G, s, e, r, mu # Mixture of random and extremes rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)] mus = [np.zeros(m) + 0.33, np.zeros(m) + 1e-6, np.random.rand(m) * 0.33] - for r, mu in itertools.product(rs, mus): + e = self.genotype_emission(mu, m) + + for s, r, mu in itertools.product(genotypes, rs, mus): r[0] = 0 e = self.genotype_emission(mu, m) yield n, m, G, s, e, r, mu diff --git a/python/tests/test_genotype_matching_viterbi.py b/python/tests/test_genotype_matching_viterbi.py index 89377bdb33..acab5d1c28 100644 --- 
a/python/tests/test_genotype_matching_viterbi.py +++ b/python/tests/test_genotype_matching_viterbi.py @@ -13,6 +13,8 @@ REF_HOM_OBS_HET = 1 REF_HET_OBS_HOM = 2 +MISSING = -1 + class ValueTransition: """Simple struct holding value transition values.""" @@ -390,6 +392,7 @@ def update_probabilities(self, site, genotype_state): ] query_is_het = genotype_state == 1 + query_is_missing = genotype_state == MISSING for st1 in T: u1 = st1.tree_node @@ -423,6 +426,7 @@ def update_probabilities(self, site, genotype_state): match, template_is_het, query_is_het, + query_is_missing, u1, u2, ) @@ -486,6 +490,7 @@ def compute_next_probability_dict( is_match, template_is_het, query_is_het, + query_is_missing, node_1, node_2, ): @@ -830,6 +835,7 @@ def compute_next_probability_dict( is_match, template_is_het, query_is_het, + query_is_missing, node_1, node_2, ): @@ -841,26 +847,28 @@ def compute_next_probability_dict( double_recombination_required = False single_recombination_required = False - template_is_hom = np.logical_not(template_is_het) - query_is_hom = np.logical_not(query_is_het) - - EQUAL_BOTH_HOM = np.logical_and( - np.logical_and(is_match, template_is_hom), query_is_hom - ) - UNEQUAL_BOTH_HOM = np.logical_and( - np.logical_and(np.logical_not(is_match), template_is_hom), query_is_hom - ) - BOTH_HET = np.logical_and(template_is_het, query_is_het) - REF_HOM_OBS_HET = np.logical_and(template_is_hom, query_is_het) - REF_HET_OBS_HOM = np.logical_and(template_is_het, query_is_hom) - - p_e = ( - EQUAL_BOTH_HOM * (1 - mu) ** 2 - + UNEQUAL_BOTH_HOM * (mu**2) - + REF_HOM_OBS_HET * (2 * mu * (1 - mu)) - + REF_HET_OBS_HOM * (mu * (1 - mu)) - + BOTH_HET * ((1 - mu) ** 2 + mu**2) - ) + if query_is_missing: + p_e = 1 + else: + template_is_hom = np.logical_not(template_is_het) + query_is_hom = np.logical_not(query_is_het) + equal_both_hom = np.logical_and( + np.logical_and(is_match, template_is_hom), query_is_hom + ) + unequal_both_hom = np.logical_and( + 
np.logical_and(np.logical_not(is_match), template_is_hom), query_is_hom + ) + both_het = np.logical_and(template_is_het, query_is_het) + ref_hom_obs_het = np.logical_and(template_is_hom, query_is_het) + ref_het_obs_hom = np.logical_and(template_is_het, query_is_hom) + + p_e = ( + equal_both_hom * (1 - mu) ** 2 + + unequal_both_hom * (mu**2) + + ref_hom_obs_het * (2 * mu * (1 - mu)) + + ref_het_obs_hom * (mu * (1 - mu)) + + both_het * ((1 - mu) ** 2 + mu**2) + ) no_switch = (1 - r) ** 2 + 2 * (r_n * (1 - r)) + r_n**2 single_switch = r_n * (1 - r) + r_n**2 @@ -919,6 +927,21 @@ def example_genotypes(self, ts): s = H[:, 0].reshape(1, H.shape[0]) + H[:, 1].reshape(1, H.shape[0]) H = H[:, 2:] + genotypes = [ + s, + H[:, -1].reshape(1, H.shape[0]) + H[:, -2].reshape(1, H.shape[0]), + ] + + s_tmp = s.copy() + s_tmp[0, -1] = MISSING + genotypes.append(s_tmp) + s_tmp = s.copy() + s_tmp[0, ts.num_sites // 2] = MISSING + genotypes.append(s_tmp) + s_tmp = s.copy() + s_tmp[0, :] = MISSING + genotypes.append(s_tmp) + m = ts.get_num_sites() n = H.shape[1] @@ -926,11 +949,11 @@ def example_genotypes(self, ts): for i in range(m): G[i, :, :] = np.add.outer(H[i, :], H[i, :]) - return H, G, s + return H, G, genotypes def example_parameters_genotypes(self, ts, seed=42): np.random.seed(seed) - H, G, s = self.example_genotypes(ts) + H, G, genotypes = self.example_genotypes(ts) n = H.shape[1] m = ts.get_num_sites() @@ -941,13 +964,16 @@ def example_parameters_genotypes(self, ts, seed=42): e = self.genotype_emission(mu, m) - yield n, m, G, s, e, r, mu + for s in genotypes: + yield n, m, G, s, e, r, mu # Mixture of random and extremes rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)] mus = [np.zeros(m) + 0.33, np.zeros(m) + 1e-6, np.random.rand(m) * 0.33] - for r, mu in itertools.product(rs, mus): + e = self.genotype_emission(mu, m) + + for s, r, mu in itertools.product(genotypes, rs, mus): r[0] = 0 e = self.genotype_emission(mu, m) yield n, m, G, s, e, r, mu From 
5eb173ab4e45a05b35b8cbbdcfbed83a0b3d1853 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Mon, 10 Jul 2023 16:21:15 +0100 Subject: [PATCH 54/84] Fix benchmark CI --- .github/workflows/tests.yml | 6 +++--- python/requirements/benchmark.txt | 9 +++++++++ 2 files changed, 12 insertions(+), 3 deletions(-) create mode 100644 python/requirements/benchmark.txt diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2da829034e..473f04e893 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -34,11 +34,11 @@ jobs: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: - python-version: '3.10' + python-version: '3.11' cache: 'pip' - cache-dependency-path: python/requirements/development.txt + cache-dependency-path: python/requirements/benchmark.txt - name: Install deps - run: pip install -r python/requirements/development.txt + run: pip install -r python/requirements/benchmark.txt - name: Build module run: | cd python diff --git a/python/requirements/benchmark.txt b/python/requirements/benchmark.txt new file mode 100644 index 0000000000..12a0be4060 --- /dev/null +++ b/python/requirements/benchmark.txt @@ -0,0 +1,9 @@ +click +psutil +tqdm +matplotlib +si-prefix +jsonschema +svgwrite +msprime +PyYAML \ No newline at end of file From a3b095f5d38357326185095c395a19bc6a13bdb4 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 17 Feb 2023 09:25:05 +0000 Subject: [PATCH 55/84] Divergence matrix tree-by-tree algorithms Implement the basic version of the divergence matrix operation using tree-by-tree algorithms, and provide interface for parallelising along the genome. 
--- c/tests/test_stats.c | 345 ++++++++++- c/tests/test_trees.c | 9 +- c/tests/testlib.c | 12 +- c/tests/testlib.h | 4 +- c/tskit/core.c | 9 + c/tskit/core.h | 10 + c/tskit/trees.c | 576 ++++++++++++++++- c/tskit/trees.h | 4 + python/_tskitmodule.c | 76 +++ python/tests/test_divmat.py | 1064 ++++++++++++++++++++++++++++++++ python/tests/test_highlevel.py | 4 +- python/tests/test_lowlevel.py | 20 + python/tests/tsutil.py | 20 +- python/tskit/trees.py | 106 +++- 14 files changed, 2230 insertions(+), 29 deletions(-) create mode 100644 python/tests/test_divmat.py diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c index 35991288d4..2d5bc97fd5 100644 --- a/c/tests/test_stats.c +++ b/c/tests/test_stats.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -262,6 +262,48 @@ verify_mean_descendants(tsk_treeseq_t *ts) free(C); } +/* Check the divergence matrix by running against the stats API equivalent + * code. NOTE: this will not always be equal in site mode, because of a slightly + * different definition wrt to multiple mutations at a site. 
+ */ +static void +verify_divergence_matrix(tsk_treeseq_t *ts, tsk_flags_t mode) +{ + int ret; + const tsk_size_t n = tsk_treeseq_get_num_samples(ts); + const tsk_id_t *samples = tsk_treeseq_get_samples(ts); + tsk_size_t sample_set_sizes[n]; + tsk_id_t index_tuples[2 * n * n]; + double D1[n * n], D2[n * n]; + tsk_size_t i, j, k; + + for (j = 0; j < n; j++) { + sample_set_sizes[j] = 1; + for (k = 0; k < n; k++) { + index_tuples[2 * (j * n + k)] = (tsk_id_t) j; + index_tuples[2 * (j * n + k) + 1] = (tsk_id_t) k; + } + } + ret = tsk_treeseq_divergence( + ts, n, sample_set_sizes, samples, n * n, index_tuples, 0, NULL, mode, D1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_treeseq_divergence_matrix(ts, 0, NULL, 0, NULL, mode, D2); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + for (j = 0; j < n; j++) { + for (k = 0; k < n; k++) { + i = j * n + k; + /* printf("%d\t%d\t%f\t%f\n", (int) j, (int) k, D1[i], D2[i]); */ + if (j == k) { + CU_ASSERT_EQUAL(D2[i], 0); + } else { + CU_ASSERT_DOUBLE_EQUAL(D1[i], D2[i], 1E-6); + } + } + } +} + typedef struct { int call_count; int error_on; @@ -973,6 +1015,128 @@ test_single_tree_general_stat_errors(void) tsk_treeseq_free(&ts); } +static void +test_single_tree_divergence_matrix(void) +{ + tsk_treeseq_t ts; + int ret; + double result[16]; + double D_branch[16] = { 0, 2, 6, 6, 2, 0, 6, 6, 6, 6, 0, 4, 6, 6, 4, 0 }; + double D_site[16] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; + + tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, NULL, + NULL, NULL, NULL, 0); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(16, result, D_branch); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(16, result, D_site); + + verify_divergence_matrix(&ts, TSK_STAT_BRANCH); + verify_divergence_matrix(&ts, TSK_STAT_SITE); + + 
tsk_treeseq_free(&ts); +} + +static void +test_single_tree_divergence_matrix_internal_samples(void) +{ + tsk_treeseq_t ts; + int ret; + double result[16]; + double D[16] = { 0, 2, 4, 3, 2, 0, 4, 3, 4, 4, 0, 1, 3, 3, 1, 0 }; + + const char *nodes = "1 0 -1 -1\n" /* 2.00┊ 6 ┊ */ + "1 0 -1 -1\n" /* ┊ ┏━┻━┓ ┊ */ + "1 0 -1 -1\n" /* 1.00┊ 4 5* ┊ */ + "0 0 -1 -1\n" /* ┊ ┏┻┓ ┏┻┓ ┊ */ + "0 1 -1 -1\n" /* 0.00┊ 0 1 2 3 ┊ */ + "1 1 -1 -1\n" /* 0 * * * 1 */ + "0 2 -1 -1\n"; + const char *edges = "0 1 4 0,1\n" + "0 1 5 2,3\n" + "0 1 6 4,5\n"; + /* One mutations per branch so we get the same as the branch length value */ + const char *sites = "0.1 A\n" + "0.2 A\n" + "0.3 A\n" + "0.4 A\n" + "0.5 A\n" + "0.6 A\n"; + const char *mutations = "0 0 T -1\n" + "1 1 T -1\n" + "2 2 T -1\n" + "3 3 T -1\n" + "4 4 T -1\n" + "5 5 T -1\n"; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, sites, mutations, NULL, NULL, 0); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(16, result, D); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(16, result, D); + + verify_divergence_matrix(&ts, TSK_STAT_BRANCH); + verify_divergence_matrix(&ts, TSK_STAT_SITE); + + tsk_treeseq_free(&ts); +} + +static void +test_single_tree_divergence_matrix_multi_root(void) +{ + tsk_treeseq_t ts; + int ret; + double result[16]; + double D_branch[16] = { 0, 2, 3, 3, 2, 0, 3, 3, 3, 3, 0, 4, 3, 3, 4, 0 }; + double D_site[16] = { 0, 4, 6, 6, 4, 0, 6, 6, 6, 6, 0, 8, 6, 6, 8, 0 }; + + const char *nodes = "1 0 -1 -1\n" + "1 0 -1 -1\n" /* 2.00┊ 5 ┊ */ + "1 0 -1 -1\n" /* 1.00┊ 4 ┊ */ + "1 0 -1 -1\n" /* ┊ ┏┻┓ ┏┻┓ ┊ */ + "0 1 -1 -1\n" /* 0.00┊ 0 1 2 3 ┊ */ + "0 2 -1 -1\n"; /* 0 * * * * 1 */ + const char *edges = "0 1 4 0,1\n" + "0 1 5 2,3\n"; + /* Two mutations per branch unit so we get twice branch length value */ + const char 
*sites = "0.1 A\n" + "0.2 A\n" + "0.3 A\n" + "0.4 A\n"; + const char *mutations = "0 0 B -1\n" + "0 0 C 0\n" + "1 1 B -1\n" + "1 1 C 2\n" + "2 2 B -1\n" + "2 2 C 4\n" + "2 2 D 5\n" + "2 2 E 6\n" + "3 3 B -1\n" + "3 3 C 8\n" + "3 3 D 9\n" + "3 3 E 10\n"; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, sites, mutations, NULL, NULL, 0); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(16, result, D_branch); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(16, result, D_site); + + tsk_treeseq_free(&ts); +} + static void test_paper_ex_ld(void) { @@ -1592,6 +1756,20 @@ test_paper_ex_afs(void) tsk_treeseq_free(&ts); } +static void +test_paper_ex_divergence_matrix(void) +{ + tsk_treeseq_t ts; + + tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, + paper_ex_mutations, paper_ex_individuals, NULL, 0); + + verify_divergence_matrix(&ts, TSK_STAT_BRANCH); + verify_divergence_matrix(&ts, TSK_STAT_SITE); + + tsk_treeseq_free(&ts); +} + static void test_nonbinary_ex_ld(void) { @@ -1726,6 +1904,158 @@ test_ld_silent_mutations(void) free(base_ts); } +static void +test_simplest_divergence_matrix(void) +{ + const char *nodes = "1 0 0\n" + "1 0 0\n" + "0 1 0\n"; + const char *edges = "0 1 2 0,1\n"; + tsk_treeseq_t ts; + tsk_id_t sample_ids[] = { 0, 1 }; + double D_branch[4] = { 0, 2, 2, 0 }; + double D_site[4] = { 0, 0, 0, 0 }; + double result[4]; + int ret; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0); + + ret = tsk_treeseq_divergence_matrix( + &ts, 2, sample_ids, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(4, D_branch, result); + + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + 
assert_arrays_almost_equal(4, D_site, result); + + ret = tsk_treeseq_divergence_matrix( + &ts, 2, sample_ids, 0, NULL, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(4, D_site, result); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(4, D_branch, result); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(4, D_site, result); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_NODE, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSUPPORTED_STAT_MODE); + + ret = tsk_treeseq_divergence_matrix( + &ts, 0, NULL, 0, NULL, TSK_STAT_SPAN_NORMALISE, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED); + + ret = tsk_treeseq_divergence_matrix( + &ts, 0, NULL, 0, NULL, TSK_STAT_POLARISED, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_STAT_POLARISED_UNSUPPORTED); + + ret = tsk_treeseq_divergence_matrix( + &ts, 0, NULL, 0, NULL, TSK_STAT_SITE | TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MULTIPLE_STAT_MODES); + + sample_ids[0] = -1; + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + + sample_ids[0] = 3; + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + + tsk_treeseq_free(&ts); +} + +static void +test_simplest_divergence_matrix_windows(void) +{ + const char *nodes = "1 0 0\n" + "1 0 0\n" + "0 1 0\n"; + const char *edges = "0 1 2 0,1\n"; + tsk_treeseq_t ts; + tsk_id_t sample_ids[] = { 0, 1 }; + double D_branch[8] = { 0, 1, 1, 0, 0, 1, 1, 0 }; + double D_site[8] = { 0, 0, 0, 0, 0, 0, 0, 0 }; + double result[8]; + double windows[] = { 0, 0.5, 1 }; + int ret; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, 
NULL, NULL, NULL, NULL, NULL, 0); + + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 2, windows, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(8, D_site, result); + ret = tsk_treeseq_divergence_matrix( + &ts, 2, sample_ids, 2, windows, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(8, D_branch, result); + + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, windows, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_NUM_WINDOWS); + + windows[0] = -1; + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 2, windows, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); + + windows[0] = 0.45; + windows[2] = 1.5; + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 2, windows, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); + + windows[0] = 0.55; + windows[2] = 1.0; + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 2, windows, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); + + tsk_treeseq_free(&ts); +} + +static void +test_simplest_divergence_matrix_internal_sample(void) +{ + const char *nodes = "1 0 0\n" + "1 0 0\n" + "0 1 0\n"; + const char *edges = "0 1 2 0,1\n"; + tsk_treeseq_t ts; + tsk_id_t sample_ids[] = { 0, 1, 2 }; + double result[9]; + double D_branch[9] = { 0, 2, 1, 2, 0, 1, 1, 1, 0 }; + double D_site[9] = { 0, 0, 0, 0, 0, 0, 0, 0, 0 }; + int ret; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0); + + ret = tsk_treeseq_divergence_matrix( + &ts, 3, sample_ids, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(9, D_branch, result); + + ret = tsk_treeseq_divergence_matrix( + &ts, 3, sample_ids, 0, NULL, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(9, D_site, result); + + tsk_treeseq_free(&ts); +} + +static void +test_multiroot_divergence_matrix(void) +{ + tsk_treeseq_t ts; + + 
tsk_treeseq_from_text(&ts, 10, multiroot_ex_nodes, multiroot_ex_edges, NULL, + multiroot_ex_sites, multiroot_ex_mutations, NULL, NULL, 0); + + verify_divergence_matrix(&ts, TSK_STAT_BRANCH); + verify_divergence_matrix(&ts, TSK_STAT_SITE); + + tsk_treeseq_free(&ts); +} + int main(int argc, char **argv) { @@ -1745,6 +2075,11 @@ main(int argc, char **argv) test_single_tree_genealogical_nearest_neighbours }, { "test_single_tree_general_stat", test_single_tree_general_stat }, { "test_single_tree_general_stat_errors", test_single_tree_general_stat_errors }, + { "test_single_tree_divergence_matrix", test_single_tree_divergence_matrix }, + { "test_single_tree_divergence_matrix_internal_samples", + test_single_tree_divergence_matrix_internal_samples }, + { "test_single_tree_divergence_matrix_multi_root", + test_single_tree_divergence_matrix_multi_root }, { "test_paper_ex_ld", test_paper_ex_ld }, { "test_paper_ex_mean_descendants", test_paper_ex_mean_descendants }, @@ -1785,6 +2120,7 @@ main(int argc, char **argv) { "test_paper_ex_f4", test_paper_ex_f4 }, { "test_paper_ex_afs_errors", test_paper_ex_afs_errors }, { "test_paper_ex_afs", test_paper_ex_afs }, + { "test_paper_ex_divergence_matrix", test_paper_ex_divergence_matrix }, { "test_nonbinary_ex_ld", test_nonbinary_ex_ld }, { "test_nonbinary_ex_mean_descendants", test_nonbinary_ex_mean_descendants }, @@ -1798,6 +2134,13 @@ main(int argc, char **argv) { "test_ld_multi_mutations", test_ld_multi_mutations }, { "test_ld_silent_mutations", test_ld_silent_mutations }, + { "test_simplest_divergence_matrix", test_simplest_divergence_matrix }, + { "test_simplest_divergence_matrix_windows", + test_simplest_divergence_matrix_windows }, + { "test_simplest_divergence_matrix_internal_sample", + test_simplest_divergence_matrix_internal_sample }, + { "test_multiroot_divergence_matrix", test_multiroot_divergence_matrix }, + { NULL, NULL }, }; return test_main(tests, argc, argv); diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c 
index 94e33ee487..cceb11d6fd 100644 --- a/c/tests/test_trees.c +++ b/c/tests/test_trees.c @@ -5395,7 +5395,6 @@ test_simplify_keep_input_roots_multi_tree(void) tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, paper_ex_mutations, paper_ex_individuals, NULL, 0); - tsk_treeseq_dump(&ts, "tmp.trees", 0); ret = tsk_treeseq_simplify( &ts, samples, 2, TSK_SIMPLIFY_KEEP_INPUT_ROOTS, &simplified, NULL); CU_ASSERT_EQUAL_FATAL(ret, 0); @@ -7801,7 +7800,7 @@ test_time_uncalibrated(void) tsk_size_t sample_set_sizes[] = { 2, 2 }; tsk_id_t samples[] = { 0, 1, 2, 3 }; tsk_size_t num_samples; - double result[10]; + double result[100]; double *W; double *sigma; @@ -7857,6 +7856,12 @@ test_time_uncalibrated(void) TSK_STAT_BRANCH | TSK_STAT_ALLOW_TIME_UNCALIBRATED, sigma); CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_divergence_matrix(&ts2, 0, NULL, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_TIME_UNCALIBRATED); + ret = tsk_treeseq_divergence_matrix(&ts2, 0, NULL, 0, NULL, + TSK_STAT_BRANCH | TSK_STAT_ALLOW_TIME_UNCALIBRATED, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_safe_free(W); tsk_safe_free(sigma); tsk_treeseq_free(&ts); diff --git a/c/tests/testlib.c b/c/tests/testlib.c index 823068d136..043ae5ceab 100644 --- a/c/tests/testlib.c +++ b/c/tests/testlib.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -966,6 +966,16 @@ tskit_suite_init(void) return CUE_SUCCESS; } +void +assert_arrays_almost_equal(tsk_size_t len, double *a, double *b) +{ + tsk_size_t j; + + for (j = 0; j < len; j++) { + CU_ASSERT_DOUBLE_EQUAL(a[j], b[j], 1e-9); + } +} + static int tskit_suite_cleanup(void) { diff --git a/c/tests/testlib.h b/c/tests/testlib.h index d042d60b55..69efb14781 100644 --- 
a/c/tests/testlib.h +++ b/c/tests/testlib.h @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2021 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -54,6 +54,8 @@ void parse_individuals(const char *text, tsk_individual_table_t *individual_tabl void unsort_edges(tsk_edge_table_t *edges, size_t start); +void assert_arrays_almost_equal(tsk_size_t len, double *a, double *b); + extern const char *single_tree_ex_nodes; extern const char *single_tree_ex_edges; extern const char *single_tree_ex_sites; diff --git a/c/tskit/core.c b/c/tskit/core.c index b1ea25badd..100cc78cad 100644 --- a/c/tskit/core.c +++ b/c/tskit/core.c @@ -466,6 +466,15 @@ tsk_strerror_internal(int err) ret = "Statistics using branch lengths cannot be calculated when time_units " "is 'uncalibrated'. (TSK_ERR_TIME_UNCALIBRATED)"; break; + case TSK_ERR_STAT_POLARISED_UNSUPPORTED: + ret = "The TSK_STAT_POLARISED option is not supported by this statistic. " + "(TSK_ERR_STAT_POLARISED_UNSUPPORTED)"; + break; + case TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED: + ret = "The TSK_STAT_SPAN_NORMALISE option is not supported by this " + "statistic. " + "(TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED)"; + break; /* Mutation mapping errors */ case TSK_ERR_GENOTYPES_ALL_MISSING: diff --git a/c/tskit/core.h b/c/tskit/core.h index b8b9f354ba..4d2c95212d 100644 --- a/c/tskit/core.h +++ b/c/tskit/core.h @@ -675,6 +675,16 @@ Statistics based on branch lengths were attempted when the ``time_units`` were ``uncalibrated``. */ #define TSK_ERR_TIME_UNCALIBRATED -910 +/** +The TSK_STAT_POLARISED option was passed to a statistic that does +not support it. +*/ +#define TSK_ERR_STAT_POLARISED_UNSUPPORTED -911 +/** +The TSK_STAT_SPAN_NORMALISE option was passed to a statistic that does +not support it. 
+*/ +#define TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED -912 /** @} */ /** diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 4604579e0b..cd0ad36aa2 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -1191,9 +1191,11 @@ tsk_treeseq_mean_descendants(const tsk_treeseq_t *self, * General stats framework ***********************************/ +#define TSK_REQUIRE_FULL_SPAN 1 + static int -tsk_treeseq_check_windows( - const tsk_treeseq_t *self, tsk_size_t num_windows, const double *windows) +tsk_treeseq_check_windows(const tsk_treeseq_t *self, tsk_size_t num_windows, + const double *windows, tsk_flags_t options) { int ret = TSK_ERR_BAD_WINDOWS; tsk_size_t j; @@ -1202,12 +1204,23 @@ tsk_treeseq_check_windows( ret = TSK_ERR_BAD_NUM_WINDOWS; goto out; } - /* TODO these restrictions can be lifted later if we want a specific interval. */ - if (windows[0] != 0) { - goto out; - } - if (windows[num_windows] != self->tables->sequence_length) { - goto out; + if (options & TSK_REQUIRE_FULL_SPAN) { + /* TODO the general stat code currently requires that we include the + * entire tree sequence span. 
This should be relaxed, so hopefully + * this branch (and the option) can be removed at some point */ + if (windows[0] != 0) { + goto out; + } + if (windows[num_windows] != self->tables->sequence_length) { + goto out; + } + } else { + if (windows[0] < 0) { + goto out; + } + if (windows[num_windows] > self->tables->sequence_length) { + goto out; + } } for (j = 0; j < num_windows; j++) { if (windows[j] >= windows[j + 1]) { @@ -1960,7 +1973,8 @@ tsk_treeseq_general_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, num_windows = 1; windows = default_windows; } else { - ret = tsk_treeseq_check_windows(self, num_windows, windows); + ret = tsk_treeseq_check_windows( + self, num_windows, windows, TSK_REQUIRE_FULL_SPAN); if (ret != 0) { goto out; } @@ -2468,7 +2482,7 @@ tsk_treeseq_allele_frequency_spectrum(const tsk_treeseq_t *self, bool stat_site = !!(options & TSK_STAT_SITE); bool stat_branch = !!(options & TSK_STAT_BRANCH); bool stat_node = !!(options & TSK_STAT_NODE); - double default_windows[] = { 0, self->tables->sequence_length }; + const double default_windows[] = { 0, self->tables->sequence_length }; const tsk_size_t num_nodes = self->tables->nodes.num_rows; const tsk_size_t K = num_sample_sets + 1; tsk_size_t j, k, l, afs_size; @@ -2496,7 +2510,8 @@ tsk_treeseq_allele_frequency_spectrum(const tsk_treeseq_t *self, num_windows = 1; windows = default_windows; } else { - ret = tsk_treeseq_check_windows(self, num_windows, windows); + ret = tsk_treeseq_check_windows( + self, num_windows, windows, TSK_REQUIRE_FULL_SPAN); if (ret != 0) { goto out; } @@ -3331,7 +3346,7 @@ tsk_treeseq_simplify(const tsk_treeseq_t *self, const tsk_id_t *samples, } ret = tsk_treeseq_init( output, tables, TSK_TS_INIT_BUILD_INDEXES | TSK_TAKE_OWNERSHIP); - /* Once tsk_tree_init has returned ownership of tables is transferred */ + /* Once tsk_treeseq_init has returned ownership of tables is transferred */ tables = NULL; out: if (tables != NULL) { @@ -3460,6 +3475,20 @@ 
tsk_treeseq_split_edges(const tsk_treeseq_t *self, double time, tsk_flags_t flag * Tree * ======================================================== */ +/* Return the root for the specified node. + * NOTE: no bounds checking is done here. + */ +static tsk_id_t +tsk_tree_get_node_root(const tsk_tree_t *self, tsk_id_t u) +{ + const tsk_id_t *restrict parent = self->parent; + + while (parent[u] != TSK_NULL) { + u = parent[u]; + } + return u; +} + int TSK_WARN_UNUSED tsk_tree_init(tsk_tree_t *self, const tsk_treeseq_t *tree_sequence, tsk_flags_t options) { @@ -6009,3 +6038,526 @@ tsk_treeseq_kc_distance(const tsk_treeseq_t *self, const tsk_treeseq_t *other, } return ret; } + +/* + * Divergence matrix + */ + +typedef struct { + /* Note it's a waste storing the triply linked tree here, but the code + * is written on the assumption of 1-based trees and the algorithm is + * frighteningly subtle, so it doesn't seem worth messing with it + * unless we really need to save some memory */ + tsk_id_t *parent; + tsk_id_t *child; + tsk_id_t *sib; + tsk_id_t *lambda; + tsk_id_t *pi; + tsk_id_t *tau; + tsk_id_t *beta; + tsk_id_t *alpha; +} sv_tables_t; + +static int +sv_tables_init(sv_tables_t *self, tsk_size_t n) +{ + int ret = 0; + + self->parent = tsk_malloc(n * sizeof(*self->parent)); + self->child = tsk_malloc(n * sizeof(*self->child)); + self->sib = tsk_malloc(n * sizeof(*self->sib)); + self->pi = tsk_malloc(n * sizeof(*self->pi)); + self->lambda = tsk_malloc(n * sizeof(*self->lambda)); + self->tau = tsk_malloc(n * sizeof(*self->tau)); + self->beta = tsk_malloc(n * sizeof(*self->beta)); + self->alpha = tsk_malloc(n * sizeof(*self->alpha)); + if (self->parent == NULL || self->child == NULL || self->sib == NULL + || self->lambda == NULL || self->tau == NULL || self->beta == NULL + || self->alpha == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } +out: + return ret; +} + +static int +sv_tables_free(sv_tables_t *self) +{ + tsk_safe_free(self->parent); + 
tsk_safe_free(self->child); + tsk_safe_free(self->sib); + tsk_safe_free(self->lambda); + tsk_safe_free(self->pi); + tsk_safe_free(self->tau); + tsk_safe_free(self->beta); + tsk_safe_free(self->alpha); + return 0; +} +static void +sv_tables_reset(sv_tables_t *self, tsk_tree_t *tree) +{ + const tsk_size_t n = 1 + tree->num_nodes; + tsk_memset(self->parent, 0, n * sizeof(*self->parent)); + tsk_memset(self->child, 0, n * sizeof(*self->child)); + tsk_memset(self->sib, 0, n * sizeof(*self->sib)); + tsk_memset(self->pi, 0, n * sizeof(*self->pi)); + tsk_memset(self->lambda, 0, n * sizeof(*self->lambda)); + tsk_memset(self->tau, 0, n * sizeof(*self->tau)); + tsk_memset(self->beta, 0, n * sizeof(*self->beta)); + tsk_memset(self->alpha, 0, n * sizeof(*self->alpha)); +} + +static void +sv_tables_convert_tree(sv_tables_t *self, tsk_tree_t *tree) +{ + const tsk_size_t n = 1 + tree->num_nodes; + const tsk_id_t *restrict tsk_parent = tree->parent; + tsk_id_t *restrict child = self->child; + tsk_id_t *restrict parent = self->parent; + tsk_id_t *restrict sib = self->sib; + tsk_size_t j; + tsk_id_t u, v; + + for (j = 0; j < n - 1; j++) { + u = (tsk_id_t) j + 1; + v = tsk_parent[j] + 1; + sib[u] = child[v]; + child[v] = u; + parent[u] = v; + } +} + +#define LAMBDA 0 + +static void +sv_tables_build_index(sv_tables_t *self) +{ + const tsk_id_t *restrict child = self->child; + const tsk_id_t *restrict parent = self->parent; + const tsk_id_t *restrict sib = self->sib; + tsk_id_t *restrict lambda = self->lambda; + tsk_id_t *restrict pi = self->pi; + tsk_id_t *restrict tau = self->tau; + tsk_id_t *restrict beta = self->beta; + tsk_id_t *restrict alpha = self->alpha; + tsk_id_t a, n, p, h; + + p = child[LAMBDA]; + n = 0; + lambda[0] = -1; + while (p != LAMBDA) { + while (true) { + n++; + pi[p] = n; + tau[n] = LAMBDA; + lambda[n] = 1 + lambda[n >> 1]; + if (child[p] != LAMBDA) { + p = child[p]; + } else { + break; + } + } + beta[p] = n; + while (true) { + tau[beta[p]] = parent[p]; + if 
(sib[p] != LAMBDA) { + p = sib[p]; + break; + } else { + p = parent[p]; + if (p != LAMBDA) { + h = lambda[n & -pi[p]]; + beta[p] = ((n >> h) | 1) << h; + } else { + break; + } + } + } + } + + /* Begin the second traversal */ + lambda[0] = lambda[n]; + pi[LAMBDA] = 0; + beta[LAMBDA] = 0; + alpha[LAMBDA] = 0; + p = child[LAMBDA]; + while (p != LAMBDA) { + while (true) { + a = alpha[parent[p]] | (beta[p] & -beta[p]); + alpha[p] = a; + if (child[p] != LAMBDA) { + p = child[p]; + } else { + break; + } + } + while (true) { + if (sib[p] != LAMBDA) { + p = sib[p]; + break; + } else { + p = parent[p]; + if (p == LAMBDA) { + break; + } + } + } + } +} + +static void +sv_tables_build(sv_tables_t *self, tsk_tree_t *tree) +{ + sv_tables_reset(self, tree); + sv_tables_convert_tree(self, tree); + sv_tables_build_index(self); +} + +static tsk_id_t +sv_tables_mrca_one_based(const sv_tables_t *self, tsk_id_t x, tsk_id_t y) +{ + const tsk_id_t *restrict lambda = self->lambda; + const tsk_id_t *restrict pi = self->pi; + const tsk_id_t *restrict tau = self->tau; + const tsk_id_t *restrict beta = self->beta; + const tsk_id_t *restrict alpha = self->alpha; + tsk_id_t h, k, xhat, yhat, ell, j, z; + + if (beta[x] <= beta[y]) { + h = lambda[beta[y] & -beta[x]]; + } else { + h = lambda[beta[x] & -beta[y]]; + } + k = alpha[x] & alpha[y] & -(1 << h); + h = lambda[k & -k]; + j = ((beta[x] >> h) | 1) << h; + if (j == beta[x]) { + xhat = x; + } else { + ell = lambda[alpha[x] & ((1 << h) - 1)]; + xhat = tau[((beta[x] >> ell) | 1) << ell]; + } + if (j == beta[y]) { + yhat = y; + } else { + ell = lambda[alpha[y] & ((1 << h) - 1)]; + yhat = tau[((beta[y] >> ell) | 1) << ell]; + } + if (pi[xhat] <= pi[yhat]) { + z = xhat; + } else { + z = yhat; + } + return z; +} + +static tsk_id_t +sv_tables_mrca(const sv_tables_t *self, tsk_id_t x, tsk_id_t y) +{ + /* Convert to 1-based indexes and back */ + return sv_tables_mrca_one_based(self, x + 1, y + 1) - 1; +} + +static int +tsk_treeseq_check_node_bounds( + 
const tsk_treeseq_t *self, tsk_size_t num_nodes, const tsk_id_t *nodes) +{ + int ret = 0; + tsk_size_t j; + tsk_id_t u; + const tsk_id_t N = (tsk_id_t) self->tables->nodes.num_rows; + + for (j = 0; j < num_nodes; j++) { + u = nodes[j]; + if (u < 0 || u >= N) { + ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + goto out; + } + } +out: + return ret; +} + +static int +tsk_treeseq_divergence_matrix_branch(const tsk_treeseq_t *self, tsk_size_t num_samples, + const tsk_id_t *restrict samples, tsk_size_t num_windows, + const double *restrict windows, tsk_flags_t options, double *restrict result) +{ + int ret = 0; + tsk_tree_t tree; + const double *restrict nodes_time = self->tables->nodes.time; + const tsk_size_t n = num_samples; + tsk_size_t i, j, k; + tsk_id_t u, v, w, u_root, v_root; + double tu, tv, d, span, left, right, span_left, span_right; + double *restrict D; + sv_tables_t sv; + + memset(&sv, 0, sizeof(sv)); + ret = tsk_tree_init(&tree, self, 0); + if (ret != 0) { + goto out; + } + ret = sv_tables_init(&sv, self->tables->nodes.num_rows + 1); + if (ret != 0) { + goto out; + } + + if (self->time_uncalibrated && !(options & TSK_STAT_ALLOW_TIME_UNCALIBRATED)) { + ret = TSK_ERR_TIME_UNCALIBRATED; + goto out; + } + + for (i = 0; i < num_windows; i++) { + left = windows[i]; + right = windows[i + 1]; + D = result + i * n * n; + ret = tsk_tree_seek(&tree, left, 0); + if (ret != 0) { + goto out; + } + while (tree.interval.left < right && tree.index != -1) { + span_left = TSK_MAX(tree.interval.left, left); + span_right = TSK_MIN(tree.interval.right, right); + span = span_right - span_left; + sv_tables_build(&sv, &tree); + for (j = 0; j < n; j++) { + u = samples[j]; + for (k = j + 1; k < n; k++) { + v = samples[k]; + w = sv_tables_mrca(&sv, u, v); + if (w != TSK_NULL) { + u_root = w; + v_root = w; + } else { + /* Slow path - only happens for nodes in disconnected + * subtrees in a tree with multiple roots */ + u_root = tsk_tree_get_node_root(&tree, u); + v_root = 
tsk_tree_get_node_root(&tree, v); + } + tu = nodes_time[u_root] - nodes_time[u]; + tv = nodes_time[v_root] - nodes_time[v]; + d = (tu + tv) * span; + D[j * n + k] += d; + } + } + ret = tsk_tree_next(&tree); + if (ret < 0) { + goto out; + } + } + } + ret = 0; +out: + tsk_tree_free(&tree); + sv_tables_free(&sv); + return ret; +} + +static tsk_size_t +count_mutations_on_path(tsk_id_t u, tsk_id_t v, const tsk_id_t *restrict parent, + const double *restrict time, const tsk_size_t *restrict mutations_per_node) +{ + double tu, tv; + tsk_size_t count = 0; + + tu = time[u]; + tv = time[v]; + while (u != v) { + if (tu < tv) { + count += mutations_per_node[u]; + u = parent[u]; + if (u == TSK_NULL) { + break; + } + tu = time[u]; + } else { + count += mutations_per_node[v]; + v = parent[v]; + if (v == TSK_NULL) { + break; + } + tv = time[v]; + } + } + if (u != v) { + while (u != TSK_NULL) { + count += mutations_per_node[u]; + u = parent[u]; + } + while (v != TSK_NULL) { + count += mutations_per_node[v]; + v = parent[v]; + } + } + return count; +} + +static int +tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_samples, + const tsk_id_t *restrict samples, tsk_size_t num_windows, + const double *restrict windows, tsk_flags_t TSK_UNUSED(options), + double *restrict result) +{ + int ret = 0; + tsk_tree_t tree; + const tsk_size_t n = num_samples; + const tsk_size_t num_nodes = self->tables->nodes.num_rows; + const double *restrict nodes_time = self->tables->nodes.time; + tsk_size_t i, j, k, tree_site, tree_mut; + tsk_site_t site; + tsk_mutation_t mut; + tsk_id_t u, v; + double left, right, span_left, span_right; + double *restrict D; + tsk_size_t *mutations_per_node = tsk_malloc(num_nodes * sizeof(*mutations_per_node)); + + ret = tsk_tree_init(&tree, self, 0); + if (ret != 0) { + goto out; + } + if (mutations_per_node == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + + for (i = 0; i < num_windows; i++) { + left = windows[i]; + right = windows[i + 1]; 
+ D = result + i * n * n; + ret = tsk_tree_seek(&tree, left, 0); + if (ret != 0) { + goto out; + } + while (tree.interval.left < right && tree.index != -1) { + span_left = TSK_MAX(tree.interval.left, left); + span_right = TSK_MIN(tree.interval.right, right); + + /* NOTE: we could avoid this full memset across all nodes by doing + * the same loops again and decrementing at the end of the main + * tree-loop. It's probably not worth it though, because of the + * overwhelming O(n^2) below */ + tsk_memset(mutations_per_node, 0, num_nodes * sizeof(*mutations_per_node)); + for (tree_site = 0; tree_site < tree.sites_length; tree_site++) { + site = tree.sites[tree_site]; + if (span_left <= site.position && site.position < span_right) { + for (tree_mut = 0; tree_mut < site.mutations_length; tree_mut++) { + mut = site.mutations[tree_mut]; + mutations_per_node[mut.node]++; + } + } + } + + for (j = 0; j < n; j++) { + u = samples[j]; + for (k = j + 1; k < n; k++) { + v = samples[k]; + D[j * n + k] += (double) count_mutations_on_path( + u, v, tree.parent, nodes_time, mutations_per_node); + } + } + ret = tsk_tree_next(&tree); + if (ret < 0) { + goto out; + } + } + } + ret = 0; +out: + tsk_tree_free(&tree); + tsk_safe_free(mutations_per_node); + return ret; +} + +static void +fill_lower_triangle( + double *restrict result, const tsk_size_t n, const tsk_size_t num_windows) +{ + tsk_size_t i, j, k; + double *restrict D; + + /* TODO there's probably a better striding pattern that could be used here */ + for (i = 0; i < num_windows; i++) { + D = result + i * n * n; + for (j = 0; j < n; j++) { + for (k = j + 1; k < n; k++) { + D[k * n + j] = D[j * n + k]; + } + } + } +} + +int +tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples, + const tsk_id_t *samples_in, tsk_size_t num_windows, const double *windows, + tsk_flags_t options, double *result) +{ + int ret = 0; + const tsk_id_t *samples = self->samples; + tsk_size_t n = self->num_samples; + const double 
default_windows[] = { 0, self->tables->sequence_length }; + bool stat_site = !!(options & TSK_STAT_SITE); + bool stat_branch = !!(options & TSK_STAT_BRANCH); + bool stat_node = !!(options & TSK_STAT_NODE); + + if (stat_node) { + ret = TSK_ERR_UNSUPPORTED_STAT_MODE; + goto out; + } + /* If no mode is specified, we default to site mode */ + if (!(stat_site || stat_branch)) { + stat_site = true; + } + /* It's an error to specify more than one mode */ + if (stat_site + stat_branch > 1) { + ret = TSK_ERR_MULTIPLE_STAT_MODES; + goto out; + } + + if (options & TSK_STAT_POLARISED) { + ret = TSK_ERR_STAT_POLARISED_UNSUPPORTED; + goto out; + } + if (options & TSK_STAT_SPAN_NORMALISE) { + ret = TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED; + goto out; + } + + if (windows == NULL) { + num_windows = 1; + windows = default_windows; + } else { + ret = tsk_treeseq_check_windows(self, num_windows, windows, 0); + if (ret != 0) { + goto out; + } + } + + if (samples_in != NULL) { + samples = samples_in; + n = num_samples; + ret = tsk_treeseq_check_node_bounds(self, n, samples); + if (ret != 0) { + goto out; + } + } + + tsk_memset(result, 0, num_windows * n * n * sizeof(*result)); + + if (stat_branch) { + ret = tsk_treeseq_divergence_matrix_branch( + self, n, samples, num_windows, windows, options, result); + } else { + tsk_bug_assert(stat_site); + ret = tsk_treeseq_divergence_matrix_site( + self, n, samples, num_windows, windows, options, result); + } + if (ret != 0) { + goto out; + } + fill_lower_triangle(result, n, num_windows); + +out: + return ret; +} diff --git a/c/tskit/trees.h b/c/tskit/trees.h index efe9980077..10c820e1c0 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1003,6 +1003,10 @@ int tsk_treeseq_f4(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); +int tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t 
num_samples, + const tsk_id_t *samples, tsk_size_t num_windows, const double *windows, + tsk_flags_t options, double *result); + /****************************************************************************/ /* Tree */ /****************************************************************************/ diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index dea3c03fd9..8d42b50afc 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -9638,6 +9638,78 @@ TreeSequence_f4(TreeSequence *self, PyObject *args, PyObject *kwds) return TreeSequence_k_way_stat_method(self, args, kwds, 4, tsk_treeseq_f4); } +static PyObject * +TreeSequence_divergence_matrix(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + PyObject *ret = NULL; + static char *kwlist[] = { "windows", "samples", "mode", NULL }; + PyArrayObject *result_array = NULL; + PyObject *windows = NULL; + PyObject *py_samples = Py_None; + char *mode = NULL; + PyArrayObject *windows_array = NULL; + PyArrayObject *samples_array = NULL; + tsk_flags_t options = 0; + npy_intp *shape, dims[3]; + tsk_size_t num_samples, num_windows; + tsk_id_t *samples = NULL; + int err; + + if (TreeSequence_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords( + args, kwds, "O|Os", kwlist, &windows, &py_samples, &mode)) { + goto out; + } + num_samples = tsk_treeseq_get_num_samples(self->tree_sequence); + if (py_samples != Py_None) { + samples_array = (PyArrayObject *) PyArray_FROMANY( + py_samples, NPY_INT32, 1, 1, NPY_ARRAY_IN_ARRAY); + if (samples_array == NULL) { + goto out; + } + shape = PyArray_DIMS(samples_array); + samples = PyArray_DATA(samples_array); + num_samples = (tsk_size_t) shape[0]; + } + if (parse_windows(windows, &windows_array, &num_windows) != 0) { + goto out; + } + dims[0] = num_windows; + dims[1] = num_samples; + dims[2] = num_samples; + result_array = (PyArrayObject *) PyArray_SimpleNew(3, dims, NPY_FLOAT64); + if (result_array == NULL) { + goto out; + } + if 
(parse_stats_mode(mode, &options) != 0) { + goto out; + } + // clang-format off + Py_BEGIN_ALLOW_THREADS + err = tsk_treeseq_divergence_matrix( + self->tree_sequence, + num_samples, samples, + num_windows, PyArray_DATA(windows_array), + options, PyArray_DATA(result_array)); + Py_END_ALLOW_THREADS + // clang-format on + /* Clang-format insists on doing this in spite of the "off" instruction above */ + if (err != 0) + { + handle_library_error(err); + goto out; + } + ret = (PyObject *) result_array; + result_array = NULL; +out: + Py_XDECREF(result_array); + Py_XDECREF(windows_array); + Py_XDECREF(samples_array); + return ret; +} + static PyObject * TreeSequence_get_num_mutations(TreeSequence *self) { @@ -10346,6 +10418,10 @@ static PyMethodDef TreeSequence_methods[] = { .ml_meth = (PyCFunction) TreeSequence_f4, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Computes the f4 statistic." }, + { .ml_name = "divergence_matrix", + .ml_meth = (PyCFunction) TreeSequence_divergence_matrix, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the pairwise divergence matrix." 
}, { .ml_name = "split_edges", .ml_meth = (PyCFunction) TreeSequence_split_edges, .ml_flags = METH_VARARGS | METH_KEYWORDS, diff --git a/python/tests/test_divmat.py b/python/tests/test_divmat.py new file mode 100644 index 0000000000..acb2403d41 --- /dev/null +++ b/python/tests/test_divmat.py @@ -0,0 +1,1064 @@ +# MIT License +# +# Copyright (c) 2023 Tskit Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Test cases for divergence matrix based pairwise stats +""" +import collections + +import msprime +import numpy as np +import pytest + +import tskit +from tests import tsutil +from tests.test_highlevel import get_example_tree_sequences + +# ↑ See https://github.com/tskit-dev/tskit/issues/1804 for when +# we can remove this. + +DIVMAT_MODES = ["branch", "site"] + +# NOTE: this implementation of Schieber-Vishkin algorithm is done like +# this so it's easy to run with numba. It would be more naturally +# packaged as a class. 
We don't actually use numba here, but it's +# handy to have a version of the SV code lying around that can be +# run directly with numba. + + +def sv_tables_init(parent_array): + n = 1 + parent_array.shape[0] + + LAMBDA = 0 + # Triply-linked tree. FIXME we shouldn't need to build this as it's + # available already in tskit + child = np.zeros(n, dtype=np.int32) + parent = np.zeros(n, dtype=np.int32) + sib = np.zeros(n, dtype=np.int32) + + for j in range(n - 1): + u = j + 1 + v = parent_array[j] + 1 + sib[u] = child[v] + child[v] = u + parent[u] = v + + lambd = np.zeros(n, dtype=np.int32) + pi = np.zeros(n, dtype=np.int32) + tau = np.zeros(n, dtype=np.int32) + beta = np.zeros(n, dtype=np.int32) + alpha = np.zeros(n, dtype=np.int32) + + p = child[LAMBDA] + n = 0 + lambd[0] = -1 + while p != LAMBDA: + while True: + n += 1 + pi[p] = n + tau[n] = LAMBDA + lambd[n] = 1 + lambd[n >> 1] + if child[p] != LAMBDA: + p = child[p] + else: + break + beta[p] = n + while True: + tau[beta[p]] = parent[p] + if sib[p] != LAMBDA: + p = sib[p] + break + else: + p = parent[p] + if p != LAMBDA: + h = lambd[n & -pi[p]] + beta[p] = ((n >> h) | 1) << h + else: + break + + # Begin the second traversal + lambd[0] = lambd[n] + pi[LAMBDA] = 0 + beta[LAMBDA] = 0 + alpha[LAMBDA] = 0 + p = child[LAMBDA] + while p != LAMBDA: + while True: + a = alpha[parent[p]] | (beta[p] & -beta[p]) + alpha[p] = a + if child[p] != LAMBDA: + p = child[p] + else: + break + while True: + if sib[p] != LAMBDA: + p = sib[p] + break + else: + p = parent[p] + if p == LAMBDA: + break + + return lambd, pi, tau, beta, alpha + + +def _sv_mrca(x, y, lambd, pi, tau, beta, alpha): + if beta[x] <= beta[y]: + h = lambd[beta[y] & -beta[x]] + else: + h = lambd[beta[x] & -beta[y]] + k = alpha[x] & alpha[y] & -(1 << h) + h = lambd[k & -k] + j = ((beta[x] >> h) | 1) << h + if j == beta[x]: + xhat = x + else: + ell = lambd[alpha[x] & ((1 << h) - 1)] + xhat = tau[((beta[x] >> ell) | 1) << ell] + if j == beta[y]: + yhat = y + else: + ell = 
lambd[alpha[y] & ((1 << h) - 1)] + yhat = tau[((beta[y] >> ell) | 1) << ell] + if pi[xhat] <= pi[yhat]: + z = xhat + else: + z = yhat + return z + + +def sv_mrca(x, y, lambd, pi, tau, beta, alpha): + # Convert to 1-based indexes + return _sv_mrca(x + 1, y + 1, lambd, pi, tau, beta, alpha) - 1 + + +def local_root(tree, u): + while tree.parent(u) != tskit.NULL: + u = tree.parent(u) + return u + + +def branch_divergence_matrix(ts, windows=None, samples=None): + windows_specified = windows is not None + windows = [0, ts.sequence_length] if windows is None else windows + num_windows = len(windows) - 1 + samples = ts.samples() if samples is None else samples + + n = len(samples) + D = np.zeros((num_windows, n, n)) + tree = tskit.Tree(ts) + for i in range(num_windows): + left = windows[i] + right = windows[i + 1] + # print(f"WINDOW {i} [{left}, {right})") + tree.seek(left) + # Iterate over the trees in this window + while tree.interval.left < right and tree.index != -1: + span_left = max(tree.interval.left, left) + span_right = min(tree.interval.right, right) + span = span_right - span_left + # print(f"\ttree {tree.interval} [{span_left}, {span_right})") + tables = sv_tables_init(tree.parent_array) + for j in range(n): + u = samples[j] + for k in range(j + 1, n): + v = samples[k] + w = sv_mrca(u, v, *tables) + assert w == tree.mrca(u, v) + if w != tskit.NULL: + tu = ts.nodes_time[w] - ts.nodes_time[u] + tv = ts.nodes_time[w] - ts.nodes_time[v] + else: + tu = ts.nodes_time[local_root(tree, u)] - ts.nodes_time[u] + tv = ts.nodes_time[local_root(tree, v)] - ts.nodes_time[v] + d = (tu + tv) * span + D[i, j, k] += d + tree.next() + # Fill out symmetric triangle in the matrix + for j in range(n): + for k in range(j + 1, n): + D[i, k, j] = D[i, j, k] + if not windows_specified: + D = D[0] + return D + + +def divergence_matrix(ts, windows=None, samples=None, mode="site"): + assert mode in ["site", "branch"] + if mode == "site": + return site_divergence_matrix(ts, samples=samples, 
windows=windows) + else: + return branch_divergence_matrix(ts, samples=samples, windows=windows) + + +def stats_api_divergence_matrix(ts, windows=None, samples=None, mode="site"): + samples = ts.samples() if samples is None else samples + windows_specified = windows is not None + windows = [0, ts.sequence_length] if windows is None else list(windows) + num_windows = len(windows) - 1 + + if len(samples) == 0: + # FIXME: the code general stat code doesn't seem to handle zero samples + # case, need to identify MWE and file issue. + if windows_specified: + return np.zeros(shape=(num_windows, 0, 0)) + else: + return np.zeros(shape=(0, 0)) + + # Make sure that all the specified samples have the sample flag set, otherwise + # the library code will complain + tables = ts.dump_tables() + flags = tables.nodes.flags + # NOTE: this is a shortcut, setting all flags unconditionally to zero, so don't + # use this tree sequence outside this method. + flags[:] = 0 + flags[samples] = tskit.NODE_IS_SAMPLE + tables.nodes.flags = flags + ts = tables.tree_sequence() + + # FIXME We have to go through this annoying rigmarole because windows must start and + # end with 0 and L. We should relax this requirement to just making the windows + # contiguous, so that we just look at specific sections of the genome. 
+ drop = [] + if windows[0] != 0: + windows = [0] + windows + drop.append(0) + if windows[-1] != ts.sequence_length: + windows.append(ts.sequence_length) + drop.append(-1) + + n = len(samples) + sample_sets = [[u] for u in samples] + indexes = [(i, j) for i in range(n) for j in range(n)] + X = ts.divergence( + sample_sets, + indexes=indexes, + mode=mode, + span_normalise=False, + windows=windows, + ) + keep = np.ones(len(windows) - 1, dtype=bool) + keep[drop] = False + X = X[keep] + out = X.reshape((X.shape[0], n, n)) + for D in out: + np.fill_diagonal(D, 0) + if not windows_specified: + out = out[0] + return out + + +def rootward_path(tree, u, v): + while u != v: + yield u + u = tree.parent(u) + + +def site_divergence_matrix(ts, windows=None, samples=None): + windows_specified = windows is not None + windows = [0, ts.sequence_length] if windows is None else windows + num_windows = len(windows) - 1 + samples = ts.samples() if samples is None else samples + + n = len(samples) + D = np.zeros((num_windows, n, n)) + tree = tskit.Tree(ts) + for i in range(num_windows): + left = windows[i] + right = windows[i + 1] + tree.seek(left) + # Iterate over the trees in this window + while tree.interval.left < right and tree.index != -1: + span_left = max(tree.interval.left, left) + span_right = min(tree.interval.right, right) + mutations_per_node = collections.Counter() + for site in tree.sites(): + if span_left <= site.position < span_right: + for mutation in site.mutations: + mutations_per_node[mutation.node] += 1 + for j in range(n): + u = samples[j] + for k in range(j + 1, n): + v = samples[k] + w = tree.mrca(u, v) + if w != tskit.NULL: + wu = w + wv = w + else: + wu = local_root(tree, u) + wv = local_root(tree, v) + du = sum(mutations_per_node[x] for x in rootward_path(tree, u, wu)) + dv = sum(mutations_per_node[x] for x in rootward_path(tree, v, wv)) + # NOTE: we're just accumulating the raw mutation counts, not + # multiplying by span + D[i, j, k] += du + dv + tree.next() 
+ # Fill out symmetric triangle in the matrix + for j in range(n): + for k in range(j + 1, n): + D[i, k, j] = D[i, j, k] + if not windows_specified: + D = D[0] + return D + + +def check_divmat( + ts, + *, + windows=None, + samples=None, + verbosity=0, + compare_stats_api=True, + compare_lib=True, + mode="site", +): + np.set_printoptions(linewidth=500, precision=4) + # print(ts.draw_text()) + if verbosity > 1: + print(ts.draw_text()) + + D1 = divergence_matrix(ts, windows=windows, samples=samples, mode=mode) + if compare_stats_api: + # Somethings like duplicate samples aren't worth hacking around for in + # stats API. + D2 = stats_api_divergence_matrix( + ts, windows=windows, samples=samples, mode=mode + ) + # print("windows = ", windows) + # print(D1) + # print(D2) + np.testing.assert_allclose(D1, D2) + assert D1.shape == D2.shape + if compare_lib: + D3 = ts.divergence_matrix(windows=windows, samples=samples, mode=mode) + # print(D3) + assert D1.shape == D3.shape + np.testing.assert_allclose(D1, D3) + return D1 + + +class TestExamplesWithAnswer: + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_zero_samples(self, mode): + ts = tskit.Tree.generate_balanced(2).tree_sequence + D = check_divmat(ts, samples=[], mode="site") + assert D.shape == (0, 0) + + @pytest.mark.parametrize("num_windows", [1, 2, 3, 5]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_zero_samples_windows(self, num_windows, mode): + ts = tskit.Tree.generate_balanced(2).tree_sequence + windows = np.linspace(0, ts.sequence_length, num=num_windows + 1) + D = check_divmat(ts, samples=[], windows=windows, mode="site") + assert D.shape == (num_windows, 0, 0) + + @pytest.mark.parametrize("m", [0, 1, 2, 10]) + def test_single_tree_sites_per_branch(self, m): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_sites(ts, m) + D1 = check_divmat(ts, 
mode="site") + D2 = np.array( + [ + [0.0, 2.0, 4.0, 4.0], + [2.0, 0.0, 4.0, 4.0], + [4.0, 4.0, 0.0, 2.0], + [4.0, 4.0, 2.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, m * D2) + + @pytest.mark.parametrize("m", [0, 1, 2, 10]) + def test_single_tree_mutations_per_branch(self, m): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_mutations(ts, m) + # The stats API will produce a different value here, because + # we're just counting up the mutations and not reasoning about + # the state of samples at all. + D1 = check_divmat(ts, mode="site", compare_stats_api=False) + D2 = np.array( + [ + [0.0, 2.0, 4.0, 4.0], + [2.0, 0.0, 4.0, 4.0], + [4.0, 4.0, 0.0, 2.0], + [4.0, 4.0, 2.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, m * D2) + + @pytest.mark.parametrize("L", [0.1, 1, 2, 100]) + def test_single_tree_sequence_length(self, L): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4, span=L).tree_sequence + D1 = check_divmat(ts, mode="branch") + D2 = np.array( + [ + [0.0, 2.0, 4.0, 4.0], + [2.0, 0.0, 4.0, 4.0], + [4.0, 4.0, 0.0, 2.0], + [4.0, 4.0, 2.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, L * D2) + + @pytest.mark.parametrize("num_windows", [1, 2, 3, 5]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_gap_at_end(self, num_windows, mode): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ 0 1 2 3 + # 0 1 2 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_sites(ts) + tables = ts.dump_tables() + tables.sequence_length = 2 + ts = tables.tree_sequence() + windows = np.linspace(0, ts.sequence_length, num=num_windows + 1) + D1 = check_divmat(ts, windows=windows, mode=mode) + D1 = np.sum(D1, axis=0) + D2 = np.array( + [ + [0.0, 2.0, 4.0, 4.0], + [2.0, 0.0, 4.0, 4.0], + [4.0, 4.0, 0.0, 
2.0], + [4.0, 4.0, 2.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, D2) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_subset_permuted_samples(self, mode): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_sites(ts) + D1 = check_divmat(ts, samples=[1, 2, 0], mode=mode) + D2 = np.array( + [ + [0.0, 4.0, 2.0], + [4.0, 0.0, 4.0], + [2.0, 4.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, D2) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_mixed_non_sample_samples(self, mode): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_sites(ts) + D1 = check_divmat(ts, samples=[0, 5], mode=mode) + D2 = np.array( + [ + [0.0, 3.0], + [3.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, D2) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_duplicate_samples(self, mode): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_sites(ts) + D1 = check_divmat(ts, samples=[0, 0, 1], compare_stats_api=False, mode=mode) + D2 = np.array( + [ + [0.0, 0.0, 2.0], + [0.0, 0.0, 2.0], + [2.0, 2.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, D2) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_multiroot(self, mode): + # 2.00┊ ┊ + # ┊ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_sites(ts) + ts = ts.decapitate(1) + D1 = check_divmat(ts, mode=mode) + D2 = np.array( + [ + [0.0, 2.0, 2.0, 2.0], + [2.0, 0.0, 2.0, 2.0], + [2.0, 2.0, 0.0, 2.0], + [2.0, 2.0, 2.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, D2) + + @pytest.mark.parametrize( + 
["left", "right"], [(0, 10), (1, 3), (3.25, 3.75), (5, 10)] + ) + def test_single_tree_interval(self, left, right): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4, span=10).tree_sequence + D1 = check_divmat(ts, windows=[left, right], mode="branch") + D2 = np.array( + [ + [0.0, 2.0, 4.0, 4.0], + [2.0, 0.0, 4.0, 4.0], + [4.0, 4.0, 0.0, 2.0], + [4.0, 4.0, 2.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1[0], (right - left) * D2) + + @pytest.mark.parametrize("num_windows", [1, 2, 3, 5, 11]) + def test_single_tree_equal_windows(self, num_windows): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4, span=10).tree_sequence + windows = np.linspace(0, ts.sequence_length, num=num_windows + 1) + x = ts.sequence_length / num_windows + # print(windows) + D1 = check_divmat(ts, windows=windows, mode="branch") + assert D1.shape == (num_windows, 4, 4) + D2 = np.array( + [ + [0.0, 2.0, 4.0, 4.0], + [2.0, 0.0, 4.0, 4.0], + [4.0, 4.0, 0.0, 2.0], + [4.0, 4.0, 2.0, 0.0], + ] + ) + for D in D1: + np.testing.assert_array_almost_equal(D, x * D2) + + @pytest.mark.parametrize("n", [2, 3, 5]) + def test_single_tree_no_sites(self, n): + ts = tskit.Tree.generate_balanced(n, span=10).tree_sequence + D = check_divmat(ts, mode="site") + np.testing.assert_array_equal(D, np.zeros((n, n))) + + +class TestExamples: + @pytest.mark.parametrize( + "interval", [(0, 26), (1, 3), (3.25, 13.75), (5, 10), (25.5, 26)] + ) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees_interval(self, interval, mode): + ts = tsutil.all_trees_ts(4) + ts = tsutil.insert_branch_sites(ts) + assert ts.sequence_length == 26 + check_divmat(ts, windows=interval, mode=mode) + + @pytest.mark.parametrize( + ["windows"], + [ + ([0, 26],), + ([0, 1, 2],), + (list(range(27)),), + ([5, 7, 9, 20],), + ([5.1, 5.2, 5.3, 5.5, 6],), + ([5.1, 5.2, 6.5],), + ], + ) + 
@pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees_windows(self, windows, mode): + ts = tsutil.all_trees_ts(4) + ts = tsutil.insert_branch_sites(ts) + assert ts.sequence_length == 26 + D = check_divmat(ts, windows=windows, mode=mode) + assert D.shape == (len(windows) - 1, 4, 4) + + @pytest.mark.parametrize("num_windows", [1, 5, 28]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees_windows_gap_at_end(self, num_windows, mode): + tables = tsutil.all_trees_ts(4).dump_tables() + tables.sequence_length = 30 + ts = tables.tree_sequence() + ts = tsutil.insert_branch_sites(ts) + assert ts.last().num_roots == 4 + windows = np.linspace(0, ts.sequence_length, num=num_windows + 1) + check_divmat(ts, windows=windows, mode=mode) + + @pytest.mark.parametrize("n", [2, 3, 5]) + @pytest.mark.parametrize("seed", range(1, 4)) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_small_sims(self, n, seed, mode): + ts = msprime.sim_ancestry( + n, + ploidy=1, + sequence_length=1000, + recombination_rate=0.01, + random_seed=seed, + ) + assert ts.num_trees >= 2 + ts = msprime.sim_mutations( + ts, rate=0.1, discrete_genome=False, random_seed=seed + ) + assert ts.num_mutations > 1 + check_divmat(ts, verbosity=0, mode=mode) + + @pytest.mark.parametrize("n", [2, 3, 5, 15]) + @pytest.mark.parametrize("num_windows", range(1, 5)) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_sims_windows(self, n, num_windows, mode): + ts = msprime.sim_ancestry( + n, + ploidy=1, + population_size=20, + sequence_length=100, + recombination_rate=0.01, + random_seed=79234, + ) + assert ts.num_trees >= 2 + ts = msprime.sim_mutations( + ts, + rate=0.01, + discrete_genome=False, + random_seed=1234, + ) + assert ts.num_mutations >= 2 + windows = np.linspace(0, ts.sequence_length, num=num_windows + 1) + check_divmat(ts, windows=windows, mode=mode) + + @pytest.mark.parametrize("n", [2, 3, 5, 15]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def 
test_single_balanced_tree(self, n, mode): + ts = tskit.Tree.generate_balanced(n).tree_sequence + ts = tsutil.insert_branch_sites(ts) + # print(ts.draw_text()) + check_divmat(ts, verbosity=0, mode=mode) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_internal_sample(self, mode): + tables = tskit.Tree.generate_balanced(4).tree_sequence.dump_tables() + flags = tables.nodes.flags + flags[3] = 0 + flags[5] = tskit.NODE_IS_SAMPLE + tables.nodes.flags = flags + ts = tables.tree_sequence() + ts = tsutil.insert_branch_sites(ts) + check_divmat(ts, verbosity=0, mode=mode) + + @pytest.mark.parametrize("seed", range(1, 5)) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_one_internal_sample_sims(self, seed, mode): + ts = msprime.sim_ancestry( + 10, + ploidy=1, + population_size=20, + sequence_length=100, + recombination_rate=0.01, + random_seed=seed, + ) + t = ts.dump_tables() + # Add a new sample directly below another sample + u = t.nodes.add_row(time=-1, flags=tskit.NODE_IS_SAMPLE) + t.edges.add_row(parent=0, child=u, left=0, right=ts.sequence_length) + t.sort() + t.build_index() + ts = t.tree_sequence() + ts = tsutil.insert_branch_sites(ts) + check_divmat(ts, mode=mode) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_missing_flanks(self, mode): + ts = msprime.sim_ancestry( + 20, + ploidy=1, + population_size=20, + sequence_length=100, + recombination_rate=0.01, + random_seed=1234, + ) + assert ts.num_trees >= 2 + ts = ts.keep_intervals([[20, 80]]) + assert ts.first().interval == (0, 20) + ts = tsutil.insert_branch_sites(ts) + check_divmat(ts, mode=mode) + + @pytest.mark.parametrize("n", [2, 3, 10]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_dangling_on_samples(self, n, mode): + # Adding non sample branches below the samples does not alter + # the overall divergence *between* the samples + ts1 = tskit.Tree.generate_balanced(n).tree_sequence + ts1 = tsutil.insert_branch_sites(ts1) + D1 = check_divmat(ts1, mode=mode) + 
tables = ts1.dump_tables() + for u in ts1.samples(): + v = tables.nodes.add_row(time=-1) + tables.edges.add_row(left=0, right=ts1.sequence_length, parent=u, child=v) + tables.sort() + tables.build_index() + ts2 = tables.tree_sequence() + D2 = check_divmat(ts2, mode=mode) + np.testing.assert_array_almost_equal(D1, D2) + + @pytest.mark.parametrize("n", [2, 3, 10]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_dangling_on_all(self, n, mode): + # Adding non sample branches below the samples does not alter + # the overall divergence *between* the samples + ts1 = tskit.Tree.generate_balanced(n).tree_sequence + ts1 = tsutil.insert_branch_sites(ts1) + D1 = check_divmat(ts1, mode=mode) + tables = ts1.dump_tables() + for u in range(ts1.num_nodes): + v = tables.nodes.add_row(time=-1) + tables.edges.add_row(left=0, right=ts1.sequence_length, parent=u, child=v) + tables.sort() + tables.build_index() + ts2 = tables.tree_sequence() + D2 = check_divmat(ts2, mode=mode) + np.testing.assert_array_almost_equal(D1, D2) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_disconnected_non_sample_topology(self, mode): + # Adding non sample branches below the samples does not alter + # the overall divergence *between* the samples + ts1 = tskit.Tree.generate_balanced(5).tree_sequence + ts1 = tsutil.insert_branch_sites(ts1) + D1 = check_divmat(ts1, mode=mode) + tables = ts1.dump_tables() + # Add an extra bit of disconnected non-sample topology + u = tables.nodes.add_row(time=0) + v = tables.nodes.add_row(time=1) + tables.edges.add_row(left=0, right=ts1.sequence_length, parent=v, child=u) + tables.sort() + tables.build_index() + ts2 = tables.tree_sequence() + D2 = check_divmat(ts2, mode=mode) + np.testing.assert_array_almost_equal(D1, D2) + + +class TestSuiteExamples: + """ + Compare the stats API method vs the library implementation for the + suite test examples. Some of these examples are too large to run the + Python code above on. 
+ """ + + def check(self, ts, windows=None, samples=None, num_threads=0, mode="branch"): + D1 = ts.divergence_matrix( + windows=windows, + samples=samples, + num_threads=num_threads, + mode=mode, + ) + D2 = stats_api_divergence_matrix( + ts, windows=windows, samples=samples, mode=mode + ) + assert D1.shape == D2.shape + if mode == "branch": + # If we have missing data then parts of the divmat are defined to be zero, + # so relative tolerances aren't useful. Because the stats API + # method necessarily involves subtracting away all of the previous + # values for an empty tree, there is a degree of numerical imprecision + # here. This value for atol is what is needed to get the tests to + # pass in practise. + has_missing_data = any(tree._has_isolated_samples() for tree in ts.trees()) + atol = 1e-12 if has_missing_data else 0 + np.testing.assert_allclose(D1, D2, atol=atol) + else: + assert mode == "site" + if np.any(ts.mutations_parent != tskit.NULL): + # The stats API computes something slightly different when we have + # recurrent mutations, so fall back to the naive version. 
+ D2 = site_divergence_matrix(ts, windows=windows, samples=samples) + np.testing.assert_array_equal(D1, D2) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_defaults(self, ts, mode): + self.check(ts, mode=mode) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_subset_samples(self, ts, mode): + n = min(ts.num_samples, 2) + self.check(ts, samples=ts.samples()[:n], mode=mode) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_windows(self, ts, mode): + windows = np.linspace(0, ts.sequence_length, num=13) + self.check(ts, windows=windows, mode=mode) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_threads_no_windows(self, ts, mode): + self.check(ts, num_threads=5, mode=mode) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_threads_windows(self, ts, mode): + windows = np.linspace(0, ts.sequence_length, num=11) + self.check(ts, num_threads=5, windows=windows, mode=mode) + + +class TestThreadsNoWindows: + def check(self, ts, num_threads, samples=None, mode=None): + D1 = ts.divergence_matrix(num_threads=0, samples=samples, mode=mode) + D2 = ts.divergence_matrix(num_threads=num_threads, samples=samples, mode=mode) + np.testing.assert_array_almost_equal(D1, D2) + + @pytest.mark.parametrize("num_threads", [1, 2, 3, 5, 26, 27]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees(self, num_threads, mode): + ts = tsutil.all_trees_ts(4) + assert ts.num_trees == 26 + self.check(ts, num_threads, mode=mode) + + @pytest.mark.parametrize("samples", [None, [0, 1]]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees_samples(self, samples, mode): + ts = tsutil.all_trees_ts(4) + 
assert ts.num_trees == 26 + self.check(ts, 2, samples, mode=mode) + + @pytest.mark.parametrize("n", [2, 3, 5, 15]) + @pytest.mark.parametrize("num_threads", range(1, 5)) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_simple_sims(self, n, num_threads, mode): + ts = msprime.sim_ancestry( + n, + ploidy=1, + population_size=20, + sequence_length=100, + recombination_rate=0.01, + random_seed=1234, + ) + assert ts.num_trees >= 2 + self.check(ts, num_threads, mode=mode) + + +class TestThreadsWindows: + def check(self, ts, num_threads, *, windows, samples=None, mode=None): + D1 = ts.divergence_matrix( + num_threads=0, windows=windows, samples=samples, mode=mode + ) + D2 = ts.divergence_matrix( + num_threads=num_threads, windows=windows, samples=samples, mode=mode + ) + np.testing.assert_array_almost_equal(D1, D2) + + @pytest.mark.parametrize("num_threads", [1, 2, 3, 5, 26, 27]) + @pytest.mark.parametrize( + ["windows"], + [ + ([0, 26],), + ([0, 1, 2],), + (list(range(27)),), + ([5, 7, 9, 20],), + ([5.1, 5.2, 5.3, 5.5, 6],), + ([5.1, 5.2, 6.5],), + ], + ) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees(self, num_threads, windows, mode): + ts = tsutil.all_trees_ts(4) + assert ts.num_trees == 26 + self.check(ts, num_threads, windows=windows, mode=mode) + + @pytest.mark.parametrize("samples", [None, [0, 1]]) + @pytest.mark.parametrize( + ["windows"], + [ + ([0, 26],), + (None,), + ], + ) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees_samples(self, samples, windows, mode): + ts = tsutil.all_trees_ts(4) + self.check(ts, 2, windows=windows, samples=samples, mode=mode) + + @pytest.mark.parametrize("num_threads", range(1, 5)) + @pytest.mark.parametrize( + ["windows"], + [ + ([0, 100],), + ([0, 50, 75, 95, 100],), + ([50, 75, 95, 100],), + ([0, 50, 75, 95],), + (list(range(100)),), + ], + ) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_simple_sims(self, num_threads, windows, mode): + ts = msprime.sim_ancestry( + 
15, + ploidy=1, + population_size=20, + sequence_length=100, + recombination_rate=0.01, + random_seed=1234, + ) + assert ts.num_trees >= 2 + ts = msprime.sim_mutations(ts, rate=0.01, random_seed=1234) + assert ts.num_mutations > 10 + self.check(ts, num_threads, windows=windows, mode=mode) + + +# NOTE these are tests that are for more general functionality that might +# get applied across many different functions, and so probably should be +# tested in another file. For now they're only used by divmat, so we can +# keep them here for simplificity. +class TestChunkByTree: + # These are based on what we get from np.array_split, there's nothing + # particularly critical about exactly how we portion things up. + @pytest.mark.parametrize( + ["num_chunks", "expected"], + [ + (1, [[0, 26]]), + (2, [[0, 13], [13, 26]]), + (3, [[0, 9], [9, 18], [18, 26]]), + (4, [[0, 7], [7, 14], [14, 20], [20, 26]]), + (5, [[0, 6], [6, 11], [11, 16], [16, 21], [21, 26]]), + ], + ) + def test_all_trees_ts_26(self, num_chunks, expected): + ts = tsutil.all_trees_ts(4) + actual = ts._chunk_sequence_by_tree(num_chunks) + np.testing.assert_equal(actual, expected) + + @pytest.mark.parametrize( + ["num_chunks", "expected"], + [ + (1, [[0, 4]]), + (2, [[0, 2], [2, 4]]), + (3, [[0, 2], [2, 3], [3, 4]]), + (4, [[0, 1], [1, 2], [2, 3], [3, 4]]), + (5, [[0, 1], [1, 2], [2, 3], [3, 4]]), + (100, [[0, 1], [1, 2], [2, 3], [3, 4]]), + ], + ) + def test_all_trees_ts_4(self, num_chunks, expected): + ts = tsutil.all_trees_ts(3) + assert ts.num_trees == 4 + actual = ts._chunk_sequence_by_tree(num_chunks) + np.testing.assert_equal(actual, expected) + + @pytest.mark.parametrize("span", [1, 2, 5, 0.3]) + @pytest.mark.parametrize( + ["num_chunks", "expected"], + [ + (1, [[0, 4]]), + (2, [[0, 2], [2, 4]]), + (3, [[0, 2], [2, 3], [3, 4]]), + (4, [[0, 1], [1, 2], [2, 3], [3, 4]]), + (5, [[0, 1], [1, 2], [2, 3], [3, 4]]), + (100, [[0, 1], [1, 2], [2, 3], [3, 4]]), + ], + ) + def test_all_trees_ts_4_trees_span(self, 
span, num_chunks, expected): + tables = tsutil.all_trees_ts(3).dump_tables() + tables.edges.left *= span + tables.edges.right *= span + tables.sequence_length *= span + ts = tables.tree_sequence() + assert ts.num_trees == 4 + actual = ts._chunk_sequence_by_tree(num_chunks) + np.testing.assert_equal(actual, np.array(expected) * span) + + @pytest.mark.parametrize("num_chunks", range(1, 5)) + def test_empty_ts(self, num_chunks): + tables = tskit.TableCollection(1) + ts = tables.tree_sequence() + chunks = ts._chunk_sequence_by_tree(num_chunks) + np.testing.assert_equal(chunks, [[0, 1]]) + + @pytest.mark.parametrize("num_chunks", range(1, 5)) + def test_single_tree(self, num_chunks): + L = 10 + ts = tskit.Tree.generate_balanced(2, span=L).tree_sequence + chunks = ts._chunk_sequence_by_tree(num_chunks) + np.testing.assert_equal(chunks, [[0, L]]) + + @pytest.mark.parametrize("num_chunks", [0, -1, 0.5]) + def test_bad_chunks(self, num_chunks): + ts = tskit.Tree.generate_balanced(2).tree_sequence + with pytest.raises(ValueError, match="Number of chunks must be an integer > 0"): + ts._chunk_sequence_by_tree(num_chunks) + + +class TestChunkWindows: + # These are based on what we get from np.array_split, there's nothing + # particularly critical about exactly how we portion things up. 
+ @pytest.mark.parametrize( + ["windows", "num_chunks", "expected"], + [ + ([0, 10], 1, [[0, 10]]), + ([0, 10], 2, [[0, 10]]), + ([0, 5, 10], 2, [[0, 5], [5, 10]]), + ([0, 5, 6, 10], 2, [[0, 5, 6], [6, 10]]), + ([0, 5, 6, 10], 3, [[0, 5], [5, 6], [6, 10]]), + ], + ) + def test_examples(self, windows, num_chunks, expected): + actual = tskit.TreeSequence._chunk_windows(windows, num_chunks) + np.testing.assert_equal(actual, expected) + + @pytest.mark.parametrize("num_chunks", [0, -1, 0.5]) + def test_bad_chunks(self, num_chunks): + with pytest.raises(ValueError, match="Number of chunks must be an integer > 0"): + tskit.TreeSequence._chunk_windows([0, 1], num_chunks) diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index ce225f1dd7..0529dda001 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -228,14 +228,14 @@ def get_gap_examples(): assert len(t.parent_dict) == 0 found = True assert found - ret.append((f"gap {x}", ts)) + ret.append((f"gap_{x}", ts)) # Give an example with a gap at the end. 
ts = msprime.simulate(10, random_seed=5, recombination_rate=1) tables = get_table_collection_copy(ts.dump_tables(), 2) tables.sites.clear() tables.mutations.clear() insert_uniform_mutations(tables, 100, list(ts.samples())) - ret.append(("gap at end", tables.tree_sequence())) + ret.append(("gap_at_end", tables.tree_sequence())) return ret diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 70ef08143c..530ec7223f 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -1529,6 +1529,26 @@ def test_kc_distance(self): x2 = ts2.get_kc_distance(ts1, lambda_) assert x1 == x2 + def test_divergence_matrix(self): + n = 10 + ts = self.get_example_tree_sequence(n, random_seed=12) + D = ts.divergence_matrix([0, ts.get_sequence_length()]) + assert D.shape == (1, n, n) + D = ts.divergence_matrix([0, ts.get_sequence_length()], samples=[0, 1]) + assert D.shape == (1, 2, 2) + with pytest.raises(TypeError): + ts.divergence_matrix(windoze=[0, 1]) + with pytest.raises(ValueError, match="at least 2"): + ts.divergence_matrix(windows=[0]) + with pytest.raises(_tskit.LibraryError, match="BAD_WINDOWS"): + ts.divergence_matrix(windows=[-1, 0, 1]) + with pytest.raises(ValueError): + ts.divergence_matrix(windows=[0, 1], samples="sdf") + with pytest.raises(ValueError, match="Unrecognised stats mode"): + ts.divergence_matrix(windows=[0, 1], mode="sdf") + with pytest.raises(_tskit.LibraryError, match="UNSUPPORTED_STAT_MODE"): + ts.divergence_matrix(windows=[0, 1], mode="node") + def test_load_tables_build_indexes(self): for ts in self.get_example_tree_sequences(): tables = _tskit.TableCollection(sequence_length=ts.get_sequence_length()) diff --git a/python/tests/tsutil.py b/python/tests/tsutil.py index 6f3b080ce5..34334e9be0 100644 --- a/python/tests/tsutil.py +++ b/python/tests/tsutil.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # Copyright (C) 2017 University of 
Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -81,6 +81,8 @@ def insert_branch_mutations(ts, mutations_per_branch=1): Returns a copy of the specified tree sequence with a mutation on every branch in every tree. """ + if mutations_per_branch == 0: + return ts tables = ts.dump_tables() tables.sites.clear() tables.mutations.clear() @@ -146,23 +148,26 @@ def insert_discrete_time_mutations(ts, num_times=4, num_sites=10): return tables.tree_sequence() -def insert_branch_sites(ts): +def insert_branch_sites(ts, m=1): """ - Returns a copy of the specified tree sequence with a site on every branch + Returns a copy of the specified tree sequence with m sites on every branch of every tree. """ + if m == 0: + return ts tables = ts.dump_tables() tables.sites.clear() tables.mutations.clear() for tree in ts.trees(): left, right = tree.interval - delta = (right - left) / len(list(tree.nodes())) + delta = (right - left) / (m * len(list(tree.nodes()))) x = left for u in tree.nodes(): if tree.parent(u) != tskit.NULL: - site = tables.sites.add_row(position=x, ancestral_state="0") - tables.mutations.add_row(site=site, node=u, derived_state="1") - x += delta + for _ in range(m): + site = tables.sites.add_row(position=x, ancestral_state="0") + tables.mutations.add_row(site=site, node=u, derived_state="1") + x += delta add_provenance(tables.provenances, "insert_branch_sites") return tables.tree_sequence() @@ -1774,7 +1779,6 @@ def update_counts(edge, left, sign): def genealogical_nearest_neighbours(ts, focal, reference_sets): - reference_set_map = np.zeros(ts.num_nodes, dtype=int) - 1 for k, reference_set in enumerate(reference_sets): for u in reference_set: diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 9ccae3488d..1c26956494 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -7695,8 +7695,80 @@ def divergence( span_normalise=span_normalise, ) - # JK: commenting this out for now to get the other methods well 
tested. - # Issue: https://github.com/tskit-dev/tskit/issues/201 + ############################################ + # Pairwise sample x sample statistics + ############################################ + + def _chunk_sequence_by_tree(self, num_chunks): + """ + Return list of (left, right) genome interval tuples that contain + approximately equal numbers of trees as a 2D numpy array. A + maximum of self.num_trees single-tree intervals can be returned. + """ + if num_chunks <= 0 or int(num_chunks) != num_chunks: + raise ValueError("Number of chunks must be an integer > 0") + num_chunks = min(self.num_trees, num_chunks) + breakpoints = self.breakpoints(as_array=True)[:-1] + splits = np.array_split(breakpoints, num_chunks) + chunks = [] + for j in range(num_chunks - 1): + chunks.append((splits[j][0], splits[j + 1][0])) + chunks.append((splits[-1][0], self.sequence_length)) + return chunks + + @staticmethod + def _chunk_windows(windows, num_chunks): + """ + Returns a list of (at most) num_chunks windows, which represent splitting + up the specified list of windows into roughly equal work. + + Currently this is implemented by just splitting up into roughly equal + numbers of windows in each chunk. + """ + if num_chunks <= 0 or int(num_chunks) != num_chunks: + raise ValueError("Number of chunks must be an integer > 0") + num_chunks = min(len(windows) - 1, num_chunks) + splits = np.array_split(windows[:-1], num_chunks) + chunks = [] + for j in range(num_chunks - 1): + chunk = np.append(splits[j], splits[j + 1][0]) + chunks.append(chunk) + chunk = np.append(splits[-1], windows[-1]) + chunks.append(chunk) + return chunks + + def _parallelise_divmat_by_tree(self, num_threads, **kwargs): + """ + No windows were specified, so we can chunk up the whole genome by + tree, and do a simple sum of the results. 
+ """ + + def worker(interval): + return self._ll_tree_sequence.divergence_matrix(interval, **kwargs) + + work = self._chunk_sequence_by_tree(num_threads) + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as pool: + results = pool.map(worker, work) + return sum(results) + + def _parallelise_divmat_by_window(self, windows, num_threads, **kwargs): + """ + We assume we have a number of windows that's >= to the number + of threads available, and let each thread have a chunk of the + windows. There will definitely cases where this leads to + pathological behaviour, so we may need a more sophisticated + strategy at some point. + """ + + def worker(sub_windows): + return self._ll_tree_sequence.divergence_matrix(sub_windows, **kwargs) + + work = self._chunk_windows(windows, num_threads) + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(worker, sub_windows) for sub_windows in work] + concurrent.futures.wait(futures) + return np.vstack([future.result() for future in futures]) + # def divergence_matrix(self, sample_sets, windows=None, mode="site"): # """ # Finds the mean divergence between pairs of samples from each set of @@ -7730,6 +7802,36 @@ def divergence( # A[w, i, j] = A[w, j, i] = x[w][k] # k += 1 # return A + # NOTE: see older definition of divmat here, which may be useful when documenting + # this function. 
See https://github.com/tskit-dev/tskit/issues/2781 + def divergence_matrix( + self, *, windows=None, samples=None, num_threads=0, mode=None + ): + windows_specified = windows is not None + windows = [0, self.sequence_length] if windows is None else windows + + mode = "site" if mode is None else mode + + # NOTE: maybe we want to use a different default for num_threads here, just + # following the approach in GNN + if num_threads <= 0: + D = self._ll_tree_sequence.divergence_matrix( + windows, samples=samples, mode=mode + ) + else: + if windows_specified: + D = self._parallelise_divmat_by_window( + windows, num_threads, samples=samples, mode=mode + ) + else: + D = self._parallelise_divmat_by_tree( + num_threads, samples=samples, mode=mode + ) + + if not windows_specified: + # Drop the windows dimension + D = D[0] + return D def genetic_relatedness( self, From 555913cbf830c5c680b5cb26e5f96b3c6ab17ad1 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Mon, 12 Jun 2023 12:44:34 +0100 Subject: [PATCH 56/84] Add asdict to Dataclass --- python/CHANGELOG.rst | 5 ++++- python/tests/test_tables.py | 8 ++++++++ python/tskit/util.py | 7 +++++++ 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 3a851c3e87..94ad670759 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -1,11 +1,14 @@ -------------------- -[0.5.X] - 2023-XX-XX +[0.5.6] - 2023-XX-XX -------------------- **Features** - Add ``TreeSequence.impute_unknown_mutations_time`` method to return an array of mutation times based on the times of associated nodes (:user:`duncanMR`, :pr:`2760`, :issue:`2758`) +- Add ``asdict`` to all dataclasses. These are returned when you access a row or + other tree sequence object. 
(:user:`benjeffery`, :pr:`2759`, :issue:`2719`) + -------------------- [0.5.5] - 2023-05-17 -------------------- diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py index b6d207e2e0..265fee48b2 100644 --- a/python/tests/test_tables.py +++ b/python/tests/test_tables.py @@ -140,6 +140,14 @@ def table_5row(self, test_rows): table_5row.add_row(**row) return table_5row + def test_asdict(self, table, test_rows): + for table_row, test_row in zip(table, test_rows): + for k, v in table_row.asdict().items(): + if isinstance(v, np.ndarray): + assert np.array_equal(v, test_row[k]) + else: + assert v == test_row[k] + def test_max_rows_increment(self): for bad_value in [-1, -(2**10)]: with pytest.raises(ValueError): diff --git a/python/tskit/util.py b/python/tskit/util.py index 72f08499d6..28e9876b5a 100644 --- a/python/tskit/util.py +++ b/python/tskit/util.py @@ -47,6 +47,13 @@ def replace(self, **kwargs): """ return dataclasses.replace(self, **kwargs) + def asdict(self, **kwargs): + """ + Return a new dict which maps field names to their corresponding values + in this dataclass. 
+ """ + return dataclasses.asdict(self, **kwargs) + def canonical_json(obj): """ From 9a093cf5f2992456e7febf96db69c888c9dd86e2 Mon Sep 17 00:00:00 2001 From: peter Date: Mon, 15 Mar 2021 07:13:51 -0700 Subject: [PATCH 57/84] matmul plumbing --- c/tests/test_stats.c | 20 +++++++++ c/tskit/trees.c | 82 ++++++++++++++++++++++++++++++++++- c/tskit/trees.h | 11 +++++ python/_tskitmodule.c | 99 +++++++++++++++++++++++++++++++++++++++++++ python/tskit/trees.py | 83 ++++++++++++++++++++++++++++++++++++ 5 files changed, 294 insertions(+), 1 deletion(-) diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c index 2d5bc97fd5..154c0b6296 100644 --- a/c/tests/test_stats.c +++ b/c/tests/test_stats.c @@ -1504,6 +1504,24 @@ test_paper_ex_genetic_relatedness(void) tsk_treeseq_free(&ts); } +static void +test_paper_ex_genetic_relatedness_weighted(void) +{ + tsk_treeseq_t ts; + double weights[] = { 1.2, 0.1, 0.0, 0.0, 3.4, 5.0, 1.0, -1.0 }; + tsk_id_t indexes[] = { 0, 0, 0, 1 }; + double result[2]; + int ret; + + tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, + paper_ex_mutations, paper_ex_individuals, NULL, 0); + + ret = tsk_treeseq_genetic_relatedness_weighted( + &ts, 2, weights, 2, indexes, 0, NULL, result, TSK_STAT_SITE); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_treeseq_free(&ts); +} + static void test_paper_ex_genetic_relatedness_errors(void) { @@ -2108,6 +2126,8 @@ main(int argc, char **argv) { "test_paper_ex_genetic_relatedness_errors", test_paper_ex_genetic_relatedness_errors }, { "test_paper_ex_genetic_relatedness", test_paper_ex_genetic_relatedness }, + { "test_paper_ex_genetic_relatedness_weighted", + test_paper_ex_genetic_relatedness_weighted }, { "test_paper_ex_Y2_errors", test_paper_ex_Y2_errors }, { "test_paper_ex_Y2", test_paper_ex_Y2 }, { "test_paper_ex_f2_errors", test_paper_ex_f2_errors }, diff --git a/c/tskit/trees.c b/c/tskit/trees.c index cd0ad36aa2..0c69edd354 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -2071,6 
+2071,11 @@ typedef struct { const tsk_id_t *set_indexes; } sample_count_stat_params_t; +typedef struct { + double *total_weights; + const tsk_id_t *index_tuples; +} indexed_weight_stat_params_t; + static int tsk_treeseq_sample_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, @@ -3025,7 +3030,82 @@ tsk_treeseq_genetic_relatedness(const tsk_treeseq_t *self, tsk_size_t num_sample return ret; } -static int +genetic_relatedness_weighted_summary_func(size_t state_dim, const double *state, + size_t result_dim, double *result, void *params) +{ + indexed_weight_stat_params_t args = *(indexed_weight_stat_params_t *) params; + const double *x = state; + tsk_id_t i, j; + size_t k; + double meanx, ni, nj; + + meanx = state[state_dim - 1] / args.total_weights[state_dim - 1]; + ; + for (k = 0; k < result_dim; k++) { + i = args.index_tuples[2 * k]; + j = args.index_tuples[2 * k + 1]; + ni = args.total_weights[i]; + nj = args.total_weights[j]; + result[k] = (x[i] - ni * meanx) * (x[j] - nj * meanx) / 2; + } + return 0; +} + +int +tsk_treeseq_genetic_relatedness_weighted(const tsk_treeseq_t *self, + tsk_size_t num_weights, const double *weights, tsk_size_t num_index_tuples, + const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows, + double *result, tsk_flags_t options) +{ + int ret = 0; + tsk_size_t num_samples = self->num_samples; + size_t j, k; + indexed_weight_stat_params_t args; + const double *row; + double *new_row; + double *total_weights = malloc((num_weights + 1) * sizeof(*total_weights)); + double *new_weights = malloc((num_weights + 1) * num_samples * sizeof(*new_weights)); + + if (total_weights == NULL || new_weights == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + + // add a column of ones to W + for (k = 0; k < num_samples; k++) { + row = GET_2D_ROW(weights, num_weights, k); + new_row = GET_2D_ROW(new_weights, num_weights + 1, k); + for (j = 0; j < 
num_weights; j++) { + new_row[j] = row[j]; + } + new_row[num_weights] = 1.0; + } + + /* TODO: sanity check indexes */ + + for (j = 0; j < num_samples; j++) { + row = GET_2D_ROW(new_weights, num_weights + 1, j); + for (k = 0; k < num_weights + 1; k++) { + total_weights[k] += row[k]; + } + } + + args.total_weights = total_weights; + args.index_tuples = index_tuples; + + ret = tsk_treeseq_general_stat(self, num_weights + 1, new_weights, num_index_tuples, + genetic_relatedness_weighted_summary_func, &args, num_windows, windows, result, + options); + if (ret != 0) { + goto out; + } + +out: + tsk_safe_free(total_weights); + tsk_safe_free(new_weights); + return ret; +} + Y2_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, tsk_size_t result_dim, double *result, void *params) { diff --git a/c/tskit/trees.h b/c/tskit/trees.h index 10c820e1c0..b36a38c31f 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -943,6 +943,17 @@ int tsk_treeseq_trait_linear_model(const tsk_treeseq_t *self, tsk_size_t num_wei const double *weights, tsk_size_t num_covariates, const double *covariates, tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); +/* Two way weighted stats with covariates */ + +typedef int two_way_weighted_method(const tsk_treeseq_t *self, tsk_size_t num_weights, + const double *weights, tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, + tsk_size_t num_windows, const double *windows, double *result, tsk_flags_t options); + +int tsk_treeseq_genetic_relatedness_weighted(const tsk_treeseq_t *self, + tsk_size_t num_weights, const double *weights, tsk_size_t num_index_tuples, + const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows, + double *result, tsk_flags_t options); + /* One way sample set stats */ typedef int one_way_sample_stat_method(const tsk_treeseq_t *self, diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 8d42b50afc..ceaf45c6df 100644 --- a/python/_tskitmodule.c +++ 
b/python/_tskitmodule.c @@ -9595,6 +9595,93 @@ TreeSequence_k_way_stat_method(TreeSequence *self, PyObject *args, PyObject *kwd return ret; } +static PyObject * +TreeSequence_k_way_weighted_stat_method(TreeSequence *self, PyObject *args, + PyObject *kwds, npy_intp tuple_size, two_way_weighted_method *method) +{ + PyObject *ret = NULL; + static char *kwlist[] = { "weights", "indexes", "windows", "mode", "span_normalise", + "polarised", NULL }; + PyObject *weights = NULL; + PyObject *indexes = NULL; + PyObject *windows = NULL; + PyArrayObject *weights_array = NULL; + PyArrayObject *indexes_array = NULL; + PyArrayObject *windows_array = NULL; + PyArrayObject *result_array = NULL; + tsk_size_t num_windows, num_index_tuples; + npy_intp *w_shape, *shape; + tsk_flags_t options = 0; + char *mode = NULL; + int span_normalise = true; + int polarised = false; + int err; + + if (TreeSequence_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|sii", kwlist, &weights, &indexes, + &windows, &mode, &span_normalise, &polarised)) { + goto out; + } + if (parse_stats_mode(mode, &options) != 0) { + goto out; + } + if (span_normalise) { + options |= TSK_STAT_SPAN_NORMALISE; + } + if (polarised) { + options |= TSK_STAT_POLARISED; + } + if (parse_windows(windows, &windows_array, &num_windows) != 0) { + goto out; + } + weights_array = (PyArrayObject *) PyArray_FROMANY( + weights, NPY_FLOAT64, 2, 2, NPY_ARRAY_IN_ARRAY); + if (weights_array == NULL) { + goto out; + } + w_shape = PyArray_DIMS(weights_array); + if (w_shape[0] != tsk_treeseq_get_num_samples(self->tree_sequence)) { + PyErr_SetString(PyExc_ValueError, "First dimension must be num_samples"); + goto out; + } + + indexes_array = (PyArrayObject *) PyArray_FROMANY( + indexes, NPY_INT32, 2, 2, NPY_ARRAY_IN_ARRAY); + if (indexes_array == NULL) { + goto out; + } + shape = PyArray_DIMS(indexes_array); + if (shape[0] < 1 || shape[1] != tuple_size) { + PyErr_Format( + PyExc_ValueError, "indexes must 
be a k x %d array.", (int) tuple_size); + goto out; + } + num_index_tuples = shape[0]; + + result_array = TreeSequence_allocate_results_array( + self, options, num_windows, num_index_tuples); + if (result_array == NULL) { + goto out; + } + err = method(self->tree_sequence, w_shape[1], PyArray_DATA(weights_array), + num_index_tuples, PyArray_DATA(indexes_array), num_windows, + PyArray_DATA(windows_array), PyArray_DATA(result_array), options); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = (PyObject *) result_array; + result_array = NULL; +out: + Py_XDECREF(weights_array); + Py_XDECREF(indexes_array); + Py_XDECREF(windows_array); + Py_XDECREF(result_array); + return ret; +} + static PyObject * TreeSequence_divergence(TreeSequence *self, PyObject *args, PyObject *kwds) { @@ -9608,6 +9695,14 @@ TreeSequence_genetic_relatedness(TreeSequence *self, PyObject *args, PyObject *k self, args, kwds, 2, tsk_treeseq_genetic_relatedness); } +static PyObject * +TreeSequence_genetic_relatedness_weighted( + TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_k_way_weighted_stat_method( + self, args, kwds, 2, tsk_treeseq_genetic_relatedness_weighted); +} + static PyObject * TreeSequence_Y2(TreeSequence *self, PyObject *args, PyObject *kwds) { @@ -10394,6 +10489,10 @@ static PyMethodDef TreeSequence_methods[] = { .ml_meth = (PyCFunction) TreeSequence_genetic_relatedness, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Computes genetic relatedness between sample sets." }, + { .ml_name = "genetic_relatedness_weighted", + .ml_meth = (PyCFunction) TreeSequence_genetic_relatedness_weighted, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes genetic relatedness between weighted sums of samples." 
}, { .ml_name = "Y1", .ml_meth = (PyCFunction) TreeSequence_Y1, .ml_flags = METH_VARARGS | METH_KEYWORDS, diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 1c26956494..356f1fcfe4 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -7558,6 +7558,47 @@ def __k_way_sample_set_stat( stat = stat[()] return stat + def __k_way_weighted_stat( + self, + ll_method, + k, + W, + indexes=None, + windows=None, + mode=None, + span_normalise=True, + polarised=False, + ): + if indexes is None: + if W.shape[1] != k: + raise ValueError( + "Must specify indexes if there are not exactly {} columsn " + "in W.".format(k) + ) + indexes = np.arange(k, dtype=np.int32) + drop_dimension = False + indexes = util.safe_np_int_cast(indexes, np.int32) + if len(indexes.shape) == 1: + indexes = indexes.reshape((1, indexes.shape[0])) + drop_dimension = True + if len(indexes.shape) != 2 or indexes.shape[1] != k: + raise ValueError( + "Indexes must be convertable to a 2D numpy array with {} " + "columns".format(k) + ) + stat = self.__run_windowed_stat( + windows, + ll_method, + W, + indexes, + mode=mode, + span_normalise=span_normalise, + polarised=polarised, + ) + if drop_dimension: + stat = stat.reshape(stat.shape[:-1]) + return stat + ############################################ # Statistics definitions ############################################ @@ -7945,6 +7986,48 @@ def genetic_relatedness( return out + def genetic_relatedness_weighted( + self, + W, + indexes=None, + windows=None, + mode="site", + span_normalise=True, + polarised=False, + ): + r""" + Computes weighted genetic relatedness: if the k-th pair of indices is (i, j) + then the k-th column of output will be + :math:`\sum_{a,b} W_{ai} W_{bj} C_{ab}`, + where :math:`W` is the matrix of weights, and :math:`C_{ab}` is the + {meth}`.genetic_relatedness` between sample i and sample j. + + :param numpy.ndarray W: An array of values with one row for each sample and one + column for each set of weights. 
+ :param list indexes: A list of 2-tuples, or None. + :param list windows: An increasing list of breakpoints between the windows + to compute the statistic in. + :param str mode: A string giving the "type" of the statistic to be computed + (defaults to "site"). + :param bool span_normalise: Whether to divide the result by the span of the + window (defaults to True). + :return: A ndarray with shape equal to (num windows, num statistics). + """ + if W.shape[0] != self.num_samples: + raise ValueError( + "First trait dimension must be equal to number of samples." + ) + return self.__k_way_weighted_stat( + self._ll_tree_sequence.genetic_relatedness_weighted, + 2, + W, + indexes=indexes, + windows=windows, + mode=mode, + span_normalise=span_normalise, + polarised=polarised, + ) + def trait_covariance(self, W, windows=None, mode="site", span_normalise=True): """ Computes the mean squared covariances between each of the columns of ``W`` From d013c475674931d88da7318315a48b86aa226b91 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 4 Jun 2021 11:01:24 +0100 Subject: [PATCH 58/84] Add tests for weighted genetic relatedness. --- python/tests/test_lowlevel.py | 66 +++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 530ec7223f..628465d9c6 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -2116,6 +2116,60 @@ def f(indexes): f(bad_dim) +class TwoWayWeightedStatsMixin(StatsInterfaceMixin): + """ + Tests for the weighted two way sample stats. 
+ """ + + def get_example(self): + ts, method = self.get_method() + params = { + "weights": np.zeros((ts.get_num_samples(), 2)) + 0.5, + "indexes": [[0, 1]], + "windows": [0, ts.get_sequence_length()], + } + return ts, method, params + + def test_basic_example(self): + ts, method = self.get_method() + div = method( + np.zeros((ts.get_num_samples(), 1)) + 0.5, + [[0, 1]], + windows=[0, ts.get_sequence_length()], + ) + assert div.shape == (1, 1) + + def test_bad_weights(self): + ts, f, params = self.get_example() + del params["weights"] + n = ts.get_num_samples() + + for bad_weight_shape in [(n - 1, 1), (n + 1, 1), (0, 3)]: + with pytest.raises(ValueError): + f(weights=np.ones(bad_weight_shape), **params) + + def test_output_dims(self): + ts, method, params = self.get_example() + weights = params.pop("weights") + params["windows"] = [0, ts.get_sequence_length()] + + for mode in ["site", "branch"]: + out = method(weights[:, [0]], mode=mode, **params) + assert out.shape == (1, 1) + out = method(weights, mode=mode, **params) + assert out.shape == (1, 1) + out = method(weights[:, [0, 0, 0]], mode=mode, **params) + assert out.shape == (1, 1) + mode = "node" + N = ts.get_num_nodes() + out = method(weights[:, [0]], mode=mode, **params) + assert out.shape == (1, N, 1) + out = method(weights, mode=mode, **params) + assert out.shape == (1, N, 1) + out = method(weights[:, [0, 0, 0]], mode=mode, **params) + assert out.shape == (1, N, 1) + + class ThreeWaySampleStatsMixin(SampleSetMixin): """ Tests for the two way sample stats. 
@@ -2302,6 +2356,12 @@ def get_method(self): return ts, ts.f2 +class TestGeneticRelatedness(LowLevelTestCase, TwoWaySampleStatsMixin): + def get_method(self): + ts = self.get_example_tree_sequence() + return ts, ts.genetic_relatedness + + class TestY3(LowLevelTestCase, ThreeWaySampleStatsMixin): def get_method(self): ts = self.get_example_tree_sequence() @@ -2320,6 +2380,12 @@ def get_method(self): return ts, ts.f4 +class TestWeightedGeneticRelatedness(LowLevelTestCase, TwoWayWeightedStatsMixin): + def get_method(self): + ts = self.get_example_tree_sequence() + return ts, ts.genetic_relatedness_weighted + + class TestGeneralStatsInterface(LowLevelTestCase, StatsInterfaceMixin): """ Tests for the general stats interface. From 6912a466d1e5423be027a3d214bc77ddb118b16e Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 4 Jun 2021 12:04:58 +0100 Subject: [PATCH 59/84] Fix memory bug. Initialise total_weights to 0 and avoid an extra loop through arrays. Fix compile error --- c/tskit/trees.c | 40 +++++++++++++++-------------------- python/_tskitmodule.c | 2 +- python/tests/test_lowlevel.py | 23 ++++++++++++++++++++ 3 files changed, 41 insertions(+), 24 deletions(-) diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 0c69edd354..b5cb654a7d 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -3030,17 +3030,17 @@ tsk_treeseq_genetic_relatedness(const tsk_treeseq_t *self, tsk_size_t num_sample return ret; } -genetic_relatedness_weighted_summary_func(size_t state_dim, const double *state, - size_t result_dim, double *result, void *params) +static int +genetic_relatedness_weighted_summary_func(tsk_size_t state_dim, const double *state, + tsk_size_t result_dim, double *result, void *params) { indexed_weight_stat_params_t args = *(indexed_weight_stat_params_t *) params; const double *x = state; tsk_id_t i, j; - size_t k; + tsk_size_t k; double meanx, ni, nj; meanx = state[state_dim - 1] / args.total_weights[state_dim - 1]; - ; for (k = 0; k < result_dim; k++) { i = 
args.index_tuples[2 * k]; j = args.index_tuples[2 * k + 1]; @@ -3063,39 +3063,32 @@ tsk_treeseq_genetic_relatedness_weighted(const tsk_treeseq_t *self, indexed_weight_stat_params_t args; const double *row; double *new_row; - double *total_weights = malloc((num_weights + 1) * sizeof(*total_weights)); - double *new_weights = malloc((num_weights + 1) * num_samples * sizeof(*new_weights)); + double *total_weights = tsk_calloc((num_weights + 1), sizeof(*total_weights)); + double *new_weights + = tsk_malloc((num_weights + 1) * num_samples * sizeof(*new_weights)); if (total_weights == NULL || new_weights == NULL) { ret = TSK_ERR_NO_MEMORY; goto out; } - // add a column of ones to W - for (k = 0; k < num_samples; k++) { - row = GET_2D_ROW(weights, num_weights, k); - new_row = GET_2D_ROW(new_weights, num_weights + 1, k); - for (j = 0; j < num_weights; j++) { - new_row[j] = row[j]; - } - new_row[num_weights] = 1.0; - } - - /* TODO: sanity check indexes */ - + // Add a column of ones to W for (j = 0; j < num_samples; j++) { - row = GET_2D_ROW(new_weights, num_weights + 1, j); - for (k = 0; k < num_weights + 1; k++) { + row = GET_2D_ROW(weights, num_weights, j); + new_row = GET_2D_ROW(new_weights, num_weights + 1, j); + for (k = 0; k < num_weights; k++) { + new_row[k] = row[k]; total_weights[k] += row[k]; } + new_row[num_weights] = 1.0; } + total_weights[num_weights] = (double) num_samples; args.total_weights = total_weights; args.index_tuples = index_tuples; - ret = tsk_treeseq_general_stat(self, num_weights + 1, new_weights, num_index_tuples, - genetic_relatedness_weighted_summary_func, &args, num_windows, windows, result, - options); + genetic_relatedness_weighted_summary_func, &args, num_windows, windows, options, + result); if (ret != 0) { goto out; } @@ -3106,6 +3099,7 @@ tsk_treeseq_genetic_relatedness_weighted(const tsk_treeseq_t *self, return ret; } +static int Y2_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, tsk_size_t result_dim, double 
*result, void *params) { diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index ceaf45c6df..5c6bd29986 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -9642,7 +9642,7 @@ TreeSequence_k_way_weighted_stat_method(TreeSequence *self, PyObject *args, goto out; } w_shape = PyArray_DIMS(weights_array); - if (w_shape[0] != tsk_treeseq_get_num_samples(self->tree_sequence)) { + if (w_shape[0] != (npy_intp) tsk_treeseq_get_num_samples(self->tree_sequence)) { PyErr_SetString(PyExc_ValueError, "First dimension must be num_samples"); goto out; } diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 628465d9c6..d94d6e9784 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -1754,6 +1754,15 @@ def test_window_errors(self): with pytest.raises(_tskit.LibraryError): f(windows=bad_window, **params) + def test_polarisation(self): + ts, f, params = self.get_example() + with pytest.raises(TypeError): + f(polarised="sdf", **params) + x1 = f(polarised=False, **params) + x2 = f(polarised=True, **params) + # Basic check just to run both code paths + assert x1.shape == x2.shape + def test_windows_output(self): ts, f, params = self.get_example() del params["windows"] @@ -2169,6 +2178,20 @@ def test_output_dims(self): out = method(weights[:, [0, 0, 0]], mode=mode, **params) assert out.shape == (1, N, 1) + def test_set_index_errors(self): + ts, method, params = self.get_example() + del params["indexes"] + + def f(indexes): + method(indexes=indexes, **params) + + for bad_array in ["wer", {}, [[[], []], [[], []]]]: + with pytest.raises(ValueError): + f(bad_array) + for bad_dim in [[[]], [[1], [1]]]: + with pytest.raises(ValueError): + f(bad_dim) + class ThreeWaySampleStatsMixin(SampleSetMixin): """ From bb2880902d126a8a4709b17699dabc50ee339efc Mon Sep 17 00:00:00 2001 From: Brieuc Date: Thu, 29 Jun 2023 13:30:49 +0100 Subject: [PATCH 60/84] documentation for genetic_relatedness_weighted First pass at 
genetic_relatedness_weighted tests Full pass at tests for genetic_relatedness_weighted Update python/tskit/trees.py Co-authored-by: Peter Ralph Add summary func to genetic_relatedness_weighted tests Fix summary function definition in docs --- docs/python-api.md | 1 + docs/stats.md | 7 ++ python/tests/test_tree_stats.py | 199 +++++++++++++++++++++++++++++++- python/tskit/trees.py | 13 ++- 4 files changed, 214 insertions(+), 6 deletions(-) diff --git a/docs/python-api.md b/docs/python-api.md index ac53d3fce9..a8236daadf 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -321,6 +321,7 @@ Single site TreeSequence.Fst TreeSequence.genealogical_nearest_neighbours TreeSequence.genetic_relatedness + TreeSequence.genetic_relatedness_weighted TreeSequence.general_stat TreeSequence.segregating_sites TreeSequence.sample_count_stat diff --git a/docs/stats.md b/docs/stats.md index 39257e1017..72aa5d615b 100644 --- a/docs/stats.md +++ b/docs/stats.md @@ -71,6 +71,7 @@ appears beside the listed method. * Multi-way * {meth}`~TreeSequence.divergence` * {meth}`~TreeSequence.genetic_relatedness` + {meth}`~TreeSequence.genetic_relatedness_weighted` * {meth}`~TreeSequence.f4` {meth}`~TreeSequence.f3` {meth}`~TreeSequence.f2` @@ -593,6 +594,12 @@ and boolean expressions (e.g., {math}`(x > 0)`) are interpreted as 0/1. where {math}`m = \frac{1}{n}\sum_{k=1}^n x_k` with {math}`n` the total number of samples. +`genetic_relatedness_weighted` +: {math}`f(w_i, w_j, x_i, x_j) = \frac{1}{2}(x_i - w_i m) (x_j - w_j m)`, + + where {math}`m = \frac{1}{n}\sum_{k=1}^n x_k` with {math}`n` the total number + of samples, and {math}`w_j = \sum_{k=1}^n W_kj` is the sum of the weights in the {math}`j`th column of the weight matrix. 
+ `Y2` : {math}`f(x_1, x_2) = \frac{x_1 (n_2 - x_2) (n_2 - x_2 - 1)}{n_1 n_2 (n_2 - 1)}` diff --git a/python/tests/test_tree_stats.py b/python/tests/test_tree_stats.py index 7725931b73..a06a690483 100644 --- a/python/tests/test_tree_stats.py +++ b/python/tests/test_tree_stats.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # Copyright (C) 2016 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -2101,6 +2101,203 @@ def test_match_K_c0(self): self.assertArrayAlmostEqual(A, B) +############################################ +# Genetic relatedness weighted +############################################ + + +def genetic_relatedness_matrix(ts, sample_sets, windows=None, mode="site"): + n = len(sample_sets) + indexes = [ + (n1, n2) for n1, n2 in itertools.combinations_with_replacement(range(n), 2) + ] + if windows is None: + if mode == "node": + n_nodes = ts.num_nodes + K = np.zeros((n_nodes, n, n)) + out = ts.genetic_relatedness( + sample_sets, indexes, mode=mode, proportion=False, span_normalise=True + ) + for node in range(n_nodes): + this_K = np.zeros((n, n)) + this_K[np.triu_indices(n)] = out[node, :] + this_K = this_K + np.triu(this_K, 1).transpose() + K[node, :, :] = this_K + else: + K = np.zeros((n, n)) + K[np.triu_indices(n)] = ts.genetic_relatedness( + sample_sets, indexes, mode=mode, proportion=False, span_normalise=True + ) + K = K + np.triu(K, 1).transpose() + else: + windows = ts.parse_windows(windows) + n_windows = len(windows) - 1 + out = ts.genetic_relatedness( + sample_sets, + indexes, + mode=mode, + windows=windows, + proportion=False, + span_normalise=True, + ) + if mode == "node": + n_nodes = ts.num_nodes + K = np.zeros((n_windows, n_nodes, n, n)) + for win in range(n_windows): + for node in range(n_nodes): + K_this = np.zeros((n, n)) + K_this[np.triu_indices(n)] = out[win, node, :] + K_this = K_this + np.triu(K_this, 
1).transpose() + K[win, node, :, :] = K_this + else: + K = np.zeros((n_windows, n, n)) + for win in range(n_windows): + K_this = np.zeros((n, n)) + K_this[np.triu_indices(n)] = out[win, :] + K_this = K_this + np.triu(K_this, 1).transpose() + K[win, :, :] = K_this + return K + + +def genetic_relatedness_weighted(ts, W, indexes, windows=None, mode="site"): + W_mean = W.mean(axis=0) + W = W - W_mean + sample_sets = [[u] for u in ts.samples()] + K = genetic_relatedness_matrix(ts, sample_sets, windows, mode) + n_indexes = len(indexes) + n_nodes = ts.num_nodes + if windows is None: + if mode == "node": + out = np.zeros((n_nodes, n_indexes)) + else: + out = np.zeros(n_indexes) + else: + windows = ts.parse_windows(windows) + n_windows = len(windows) - 1 + if mode == "node": + out = np.zeros((n_windows, n_nodes, n_indexes)) + else: + out = np.zeros((n_windows, n_indexes)) + for pair in range(n_indexes): + i1 = indexes[pair][0] + i2 = indexes[pair][1] + if windows is None: + if mode == "node": + for node in range(n_nodes): + this_K = K[node, :, :] + out[node, pair] = W[:, i1] @ this_K @ W[:, i2] + else: + out[pair] = W[:, i1] @ K @ W[:, i2] + else: + for win in range(n_windows): + if mode == "node": + for node in range(n_nodes): + this_K = K[win, node, :, :] + out[win, node, pair] = W[:, i1] @ this_K @ W[:, i2] + else: + this_K = K[win, :, :] + out[win, pair] = W[:, i1] @ this_K @ W[:, i2] + return out + + +def example_index_pairs(weights): + assert weights.shape[1] >= 2 + yield [(0, 1)] + yield [(1, 0), (0, 1)] + if weights.shape[1] > 2: + yield [(0, 1), (1, 2), (0, 2)] + + +class TestGeneticRelatednessWeighted(StatsTestCase, WeightStatsMixin): + + # Derived classes define this to get a specific stats mode. 
+ mode = None + + def verify_definition( + self, ts, W, indexes, windows, summary_func, ts_method, definition + ): + + # Determine output_dim of the function + M = len(indexes) + + sigma1 = ts.general_stat( + W, summary_func, M, windows, mode=self.mode, span_normalise=True + ) + sigma2 = general_stat( + ts, W, summary_func, windows, mode=self.mode, span_normalise=True + ) + + sigma3 = ts_method( + W, + indexes=indexes, + windows=windows, + mode=self.mode, + ) + sigma4 = definition( + ts, + W, + indexes=indexes, + windows=windows, + mode=self.mode, + ) + assert sigma1.shape == sigma2.shape + assert sigma1.shape == sigma3.shape + assert sigma1.shape == sigma4.shape + self.assertArrayAlmostEqual(sigma1, sigma2) + self.assertArrayAlmostEqual(sigma1, sigma3) + self.assertArrayAlmostEqual(sigma1, sigma4) + + def verify(self, ts): + for W, windows in subset_combos( + self.example_weights(ts, min_size=2), example_windows(ts), p=0.1 + ): + for indexes in example_index_pairs(W): + self.verify_weighted_stat(ts, W, indexes, windows) + + def verify_weighted_stat(self, ts, W, indexes, windows): + W_mean = W.mean(axis=0) + W = W - W_mean + W_sum = W.sum(axis=0) + n = W.shape[0] + + def f(x): + mx = np.sum(x) / n + return np.array( + [ + (x[i] - W_sum[i] * mx) * (x[j] - W_sum[j] * mx) / 2 + for i, j in indexes + ] + ) + + self.verify_definition( + ts, + W, + indexes, + windows, + f, + ts.genetic_relatedness_weighted, + genetic_relatedness_weighted, + ) + + +class TestBranchGeneticRelatednessWeighted( + TestGeneticRelatednessWeighted, TopologyExamplesMixin +): + mode = "branch" + + +class TestNodeGeneticRelatednessWeighted( + TestGeneticRelatednessWeighted, TopologyExamplesMixin +): + mode = "node" + + +class TestSiteGeneticRelatednessWeighted( + TestGeneticRelatednessWeighted, MutatedTopologyExamplesMixin +): + mode = "site" + + ############################################ # Fst ############################################ diff --git a/python/tskit/trees.py b/python/tskit/trees.py 
index 356f1fcfe4..67b4db68aa 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8000,11 +8000,14 @@ def genetic_relatedness_weighted( then the k-th column of output will be :math:`\sum_{a,b} W_{ai} W_{bj} C_{ab}`, where :math:`W` is the matrix of weights, and :math:`C_{ab}` is the - {meth}`.genetic_relatedness` between sample i and sample j. - - :param numpy.ndarray W: An array of values with one row for each sample and one - column for each set of weights. - :param list indexes: A list of 2-tuples, or None. + :meth:`genetic_relatedness <.TreeSequence.genetic_relatedness>` between sample + a and sample b, summing over all pairs of samples in the tree sequence. + + :param numpy.ndarray W: An array of values with one row for each sample node and + one column for each set of weights. + :param list indexes: A list of 2-tuples, or None (default). Note that if + indexes = None, then W must have exactly two columns and this is equivalent + to indexes = [(0,1)]. :param list windows: An increasing list of breakpoints between the windows to compute the statistic in. :param str mode: A string giving the "type" of the statistic to be computed From e20a0a27b8670ff1d03d020a1a53f31ef663f01f Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Thu, 13 Jul 2023 09:37:34 +0100 Subject: [PATCH 61/84] fix docs --- python/CHANGELOG.rst | 8 +++++++- python/tskit/trees.py | 4 ++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 94ad670759..f529d930d4 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -4,7 +4,13 @@ **Features** -- Add ``TreeSequence.impute_unknown_mutations_time`` method to return an array of mutation times based on the times of associated nodes (:user:`duncanMR`, :pr:`2760`, :issue:`2758`) +- Add ``TreeSequence.genetic_relatedness_weighted`` stats method. 
+ (:user:`petrelharp`, :user:`brieuclehmann`, :user:`jeromekelleher`, + :pr:`2785`, :pr:`1246`) + +- Add ``TreeSequence.impute_unknown_mutations_time`` method to return an + array of mutation times based on the times of associated nodes + (:user:`duncanMR`, :pr:`2760`, :issue:`2758`) - Add ``asdict`` to all dataclasses. These are returned when you access a row or other tree sequence object. (:user:`benjeffery`, :pr:`2759`, :issue:`2719`) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 67b4db68aa..236a45b8b7 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -7996,12 +7996,12 @@ def genetic_relatedness_weighted( polarised=False, ): r""" - Computes weighted genetic relatedness: if the k-th pair of indices is (i, j) + Computes weighted genetic relatedness. If the k-th pair of indices is (i, j) then the k-th column of output will be :math:`\sum_{a,b} W_{ai} W_{bj} C_{ab}`, where :math:`W` is the matrix of weights, and :math:`C_{ab}` is the :meth:`genetic_relatedness <.TreeSequence.genetic_relatedness>` between sample - a and sample b, summing over all pairs of samples in the tree sequence. + a and sample b, summing over all pairs of samples in the tree sequence. :param numpy.ndarray W: An array of values with one row for each sample node and one column for each set of weights. 
From 60d75c6035cf17bff2833675485b31c8cbcb7aae Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Thu, 13 Jul 2023 11:09:10 +0100 Subject: [PATCH 62/84] Improve error handling on weighted stats --- c/tests/test_stats.c | 84 +++++++++++++++++++++---- c/tskit/core.c | 4 ++ c/tskit/core.h | 4 ++ c/tskit/trees.c | 12 +++- python/tests/test_lowlevel.py | 6 +- python/tests/test_tree_stats.py | 107 +++++++++++++++++++++++++++----- python/tskit/trees.py | 5 +- 7 files changed, 188 insertions(+), 34 deletions(-) diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c index 154c0b6296..39f2a063e0 100644 --- a/c/tests/test_stats.c +++ b/c/tests/test_stats.c @@ -345,6 +345,16 @@ verify_window_errors(tsk_treeseq_t *ts, tsk_flags_t mode) ts, 1, W, 1, general_stat_error, NULL, 2, windows, options, sigma); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); + windows[0] = -1; + ret = tsk_treeseq_general_stat( + ts, 1, W, 1, general_stat_error, NULL, 2, windows, options, sigma); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); + + windows[1] = -1; + ret = tsk_treeseq_general_stat( + ts, 1, W, 1, general_stat_error, NULL, 1, windows, options, sigma); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); + windows[0] = 10; ret = tsk_treeseq_general_stat( ts, 1, W, 1, general_stat_error, NULL, 2, windows, options, sigma); @@ -438,11 +448,10 @@ verify_node_general_stat_errors(tsk_treeseq_t *ts) static void verify_one_way_weighted_func_errors(tsk_treeseq_t *ts, one_way_weighted_method *method) { - // we don't have any specific errors for this function - // but we might add some in the future int ret; tsk_size_t num_samples = tsk_treeseq_get_num_samples(ts); double *weights = tsk_malloc(num_samples * sizeof(double)); + double bad_windows[] = { 0, -1 }; double result; tsk_size_t j; @@ -451,7 +460,10 @@ verify_one_way_weighted_func_errors(tsk_treeseq_t *ts, one_way_weighted_method * } ret = method(ts, 0, weights, 0, NULL, 0, &result); - CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_STATE_DIMS); + 
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INSUFFICIENT_WEIGHTS); + + ret = method(ts, 1, weights, 1, bad_windows, 0, &result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); free(weights); } @@ -460,12 +472,11 @@ static void verify_one_way_weighted_covariate_func_errors( tsk_treeseq_t *ts, one_way_covariates_method *method) { - // we don't have any specific errors for this function - // but we might add some in the future int ret; tsk_size_t num_samples = tsk_treeseq_get_num_samples(ts); double *weights = tsk_malloc(num_samples * sizeof(double)); double *covariates = NULL; + double bad_windows[] = { 0, -1 }; double result; tsk_size_t j; @@ -474,7 +485,10 @@ verify_one_way_weighted_covariate_func_errors( } ret = method(ts, 0, weights, 0, covariates, 0, NULL, 0, &result); - CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_STATE_DIMS); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INSUFFICIENT_WEIGHTS); + + ret = method(ts, 1, weights, 0, covariates, 1, bad_windows, 0, &result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); free(weights); } @@ -558,6 +572,28 @@ verify_two_way_stat_func_errors(tsk_treeseq_t *ts, general_sample_stat_method *m CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_SAMPLE_SET_INDEX); } +static void +verify_two_way_weighted_stat_func_errors( + tsk_treeseq_t *ts, two_way_weighted_method *method) +{ + int ret; + tsk_id_t indexes[] = { 0, 0, 0, 1 }; + double bad_windows[] = { -1, -1 }; + double weights[10]; + double result[10]; + + memset(weights, 0, sizeof(weights)); + + ret = method(ts, 2, weights, 2, indexes, 0, NULL, result, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = method(ts, 0, weights, 2, indexes, 0, NULL, result, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INSUFFICIENT_WEIGHTS); + + ret = method(ts, 2, weights, 2, indexes, 1, bad_windows, result, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); +} + static void verify_three_way_stat_func_errors(tsk_treeseq_t *ts, general_sample_stat_method *method) { @@ -1504,32 +1540,54 @@ 
test_paper_ex_genetic_relatedness(void) tsk_treeseq_free(&ts); } +static void +test_paper_ex_genetic_relatedness_errors(void) +{ + tsk_treeseq_t ts; + + tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, + paper_ex_mutations, paper_ex_individuals, NULL, 0); + verify_two_way_stat_func_errors(&ts, tsk_treeseq_genetic_relatedness); + tsk_treeseq_free(&ts); +} + static void test_paper_ex_genetic_relatedness_weighted(void) { tsk_treeseq_t ts; double weights[] = { 1.2, 0.1, 0.0, 0.0, 3.4, 5.0, 1.0, -1.0 }; tsk_id_t indexes[] = { 0, 0, 0, 1 }; - double result[2]; + double result[100]; + tsk_size_t num_weights; int ret; tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, paper_ex_mutations, paper_ex_individuals, NULL, 0); - ret = tsk_treeseq_genetic_relatedness_weighted( - &ts, 2, weights, 2, indexes, 0, NULL, result, TSK_STAT_SITE); - CU_ASSERT_EQUAL_FATAL(ret, 0); + for (num_weights = 1; num_weights < 3; num_weights++) { + ret = tsk_treeseq_genetic_relatedness_weighted( + &ts, num_weights, weights, 2, indexes, 0, NULL, result, TSK_STAT_SITE); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_genetic_relatedness_weighted( + &ts, num_weights, weights, 2, indexes, 0, NULL, result, TSK_STAT_BRANCH); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_genetic_relatedness_weighted( + &ts, num_weights, weights, 2, indexes, 0, NULL, result, TSK_STAT_NODE); + CU_ASSERT_EQUAL_FATAL(ret, 0); + } + tsk_treeseq_free(&ts); } static void -test_paper_ex_genetic_relatedness_errors(void) +test_paper_ex_genetic_relatedness_weighted_errors(void) { tsk_treeseq_t ts; tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, paper_ex_mutations, paper_ex_individuals, NULL, 0); - verify_two_way_stat_func_errors(&ts, tsk_treeseq_genetic_relatedness); + verify_two_way_weighted_stat_func_errors( + &ts, tsk_treeseq_genetic_relatedness_weighted); tsk_treeseq_free(&ts); } @@ -2128,6 +2186,8 @@ main(int 
argc, char **argv) { "test_paper_ex_genetic_relatedness", test_paper_ex_genetic_relatedness }, { "test_paper_ex_genetic_relatedness_weighted", test_paper_ex_genetic_relatedness_weighted }, + { "test_paper_ex_genetic_relatedness_weighted_errors", + test_paper_ex_genetic_relatedness_weighted_errors }, { "test_paper_ex_Y2_errors", test_paper_ex_Y2_errors }, { "test_paper_ex_Y2", test_paper_ex_Y2 }, { "test_paper_ex_f2_errors", test_paper_ex_f2_errors }, diff --git a/c/tskit/core.c b/c/tskit/core.c index 100cc78cad..5a8ed6d9ac 100644 --- a/c/tskit/core.c +++ b/c/tskit/core.c @@ -475,6 +475,10 @@ tsk_strerror_internal(int err) "statistic. " "(TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED)"; break; + case TSK_ERR_INSUFFICIENT_WEIGHTS: + ret = "Insufficient weights provided (at least 1 required). " + "(TSK_ERR_INSUFFICIENT_WEIGHTS)"; + break; /* Mutation mapping errors */ case TSK_ERR_GENOTYPES_ALL_MISSING: diff --git a/c/tskit/core.h b/c/tskit/core.h index 4d2c95212d..45a33dd8b7 100644 --- a/c/tskit/core.h +++ b/c/tskit/core.h @@ -685,6 +685,10 @@ The TSK_STAT_SPAN_NORMALISE option was passed to a statistic that does not support it. */ #define TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED -912 +/** +Insufficient weights were provided. 
+*/ +#define TSK_ERR_INSUFFICIENT_WEIGHTS -913 /** @} */ /** diff --git a/c/tskit/trees.c b/c/tskit/trees.c index b5cb654a7d..dac3ac154b 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -2639,6 +2639,10 @@ tsk_treeseq_trait_covariance(const tsk_treeseq_t *self, tsk_size_t num_weights, ret = TSK_ERR_NO_MEMORY; goto out; } + if (num_weights == 0) { + ret = TSK_ERR_INSUFFICIENT_WEIGHTS; + goto out; + } // center weights for (j = 0; j < num_samples; j++) { @@ -2710,7 +2714,7 @@ tsk_treeseq_trait_correlation(const tsk_treeseq_t *self, tsk_size_t num_weights, } if (num_weights < 1) { - ret = TSK_ERR_BAD_STATE_DIMS; + ret = TSK_ERR_INSUFFICIENT_WEIGHTS; goto out; } @@ -2823,7 +2827,7 @@ tsk_treeseq_trait_linear_model(const tsk_treeseq_t *self, tsk_size_t num_weights } if (num_weights < 1) { - ret = TSK_ERR_BAD_STATE_DIMS; + ret = TSK_ERR_INSUFFICIENT_WEIGHTS; goto out; } @@ -3071,6 +3075,10 @@ tsk_treeseq_genetic_relatedness_weighted(const tsk_treeseq_t *self, ret = TSK_ERR_NO_MEMORY; goto out; } + if (num_weights == 0) { + ret = TSK_ERR_INSUFFICIENT_WEIGHTS; + goto out; + } // Add a column of ones to W for (j = 0; j < num_samples; j++) { diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index d94d6e9784..c33f159deb 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -2153,8 +2153,12 @@ def test_bad_weights(self): del params["weights"] n = ts.get_num_samples() + for bad_weight_type in [None, [None, None]]: + with pytest.raises(ValueError, match="object of too small depth"): + f(weights=bad_weight_type, **params) + for bad_weight_shape in [(n - 1, 1), (n + 1, 1), (0, 3)]: - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="First dimension must be num_samples"): f(weights=np.ones(bad_weight_shape), **params) def test_output_dims(self): diff --git a/python/tests/test_tree_stats.py b/python/tests/test_tree_stats.py index a06a690483..99e8e11c55 100644 --- a/python/tests/test_tree_stats.py +++ 
b/python/tests/test_tree_stats.py @@ -453,7 +453,6 @@ def node_summary(u): # contains the location of the last time we updated the output for a node. last_update = np.zeros((ts.num_nodes, 1)) for (t_left, t_right), edges_out, edges_in in ts.edge_diffs(): - for edge in edges_out: u = edge.child v = edge.parent @@ -980,7 +979,6 @@ def verify(self, ts): self.verify_weighted_stat(ts, W, windows=windows) def verify_definition(self, ts, W, windows, summary_func, ts_method, definition): - # general_stat will need an extra column for p gW = self.transform_weights(W) @@ -1025,7 +1023,6 @@ def verify(self, ts): def verify_definition( self, ts, sample_sets, windows, summary_func, ts_method, definition ): - W = np.array([[u in A for A in sample_sets] for u in ts.samples()], dtype=float) def wrapped_summary_func(x): @@ -1762,7 +1759,6 @@ def divergence( class TestDivergence(StatsTestCase, TwoWaySampleSetStatsMixin): - # Derived classes define this to get a specific stats mode. mode = None @@ -1974,7 +1970,6 @@ def genetic_relatedness( class TestGeneticRelatedness(StatsTestCase, TwoWaySampleSetStatsMixin): - # Derived classes define this to get a specific stats mode. mode = None @@ -2035,7 +2030,6 @@ def wrapped_summary_func(x): self.assertArrayAlmostEqual(sigma1, sigma4) def verify_sample_sets_indexes(self, ts, sample_sets, indexes, windows): - n = np.array([len(x) for x in sample_sets]) n_total = sum(n) @@ -2209,14 +2203,12 @@ def example_index_pairs(weights): class TestGeneticRelatednessWeighted(StatsTestCase, WeightStatsMixin): - # Derived classes define this to get a specific stats mode. mode = None def verify_definition( self, ts, W, indexes, windows, summary_func, ts_method, definition ): - # Determine output_dim of the function M = len(indexes) @@ -2298,6 +2290,96 @@ class TestSiteGeneticRelatednessWeighted( mode = "site" +# NOTE: these classes don't follow the same (anti)-patterns as used elsewhere as they +# were added in several years afterwards. 
+ + +class TestGeneticRelatednessWeightedSimpleExamples: + # Values verified against the simple implementations above + site_value = 11.12 + branch_value = 14.72 + + def fixture(self): + ts = tskit.Tree.generate_balanced(5).tree_sequence + # Abitrary weights that give non-zero results + W = np.zeros((ts.num_samples, 2)) + W[0, :] = 1 + W[1, :] = 2 + return tsutil.insert_branch_sites(ts), W + + def test_no_arguments_site(self): + ts, W = self.fixture() + X = ts.genetic_relatedness_weighted(W, mode="site") + assert X.shape == tuple() + nt.assert_almost_equal(X, self.site_value) + + def test_windows_site(self): + ts, W = self.fixture() + X = ts.genetic_relatedness_weighted(W, mode="site", windows=[0, 1 - 1e-12, 1]) + assert X.shape == (2,) + nt.assert_almost_equal(X[0], self.site_value) + nt.assert_almost_equal(X[1], 0) + + def test_no_arguments_branch(self): + ts, W = self.fixture() + X = ts.genetic_relatedness_weighted(W, mode="branch") + assert X.shape == tuple() + nt.assert_almost_equal(X, self.branch_value) + + def test_windows_branch(self): + ts, W = self.fixture() + X = ts.genetic_relatedness_weighted(W, mode="branch", windows=[0, 0.5, 1]) + assert X.shape == (2,) + nt.assert_almost_equal(X, self.branch_value) + + def test_indexes_1D(self): + ts, W = self.fixture() + indexes = [0, 1] + X = ts.genetic_relatedness_weighted(W, indexes, mode="branch") + assert X.shape == tuple() + nt.assert_almost_equal(X, self.branch_value) + + def test_indexes_2D(self): + ts, W = self.fixture() + indexes = [[0, 1]] + X = ts.genetic_relatedness_weighted(W, indexes, mode="branch") + assert X.shape == (1,) + nt.assert_almost_equal(X, self.branch_value) + + def test_indexes_2D_windows(self): + ts, W = self.fixture() + indexes = [[0, 1], [0, 1]] + X = ts.genetic_relatedness_weighted( + W, indexes, windows=[0, 0.5, 1], mode="branch" + ) + assert X.shape == (2, 2) + nt.assert_almost_equal(X, self.branch_value) + + +class TestGeneticRelatednessWeightedErrors: + def ts(self): + return 
tskit.Tree.generate_balanced(3).tree_sequence + + @pytest.mark.parametrize("W", [[0], np.array([0]), np.zeros(100)]) + def test_bad_weight_size(self, W): + with pytest.raises(ValueError, match="First trait dimension"): + self.ts().genetic_relatedness_weighted(W) + + @pytest.mark.parametrize("cols", [1, 3]) + def test_no_indexes_with_non_2_cols(self, cols): + ts = self.ts() + W = np.zeros((ts.num_samples, cols)) + with pytest.raises(ValueError, match="Must specify indexes"): + ts.genetic_relatedness_weighted(W) + + @pytest.mark.parametrize("indexes", [[], [[0]], [[0, 0, 0]], [[[0], [0], [0]]]]) + def test_bad_index_shapes(self, indexes): + ts = self.ts() + W = np.zeros((ts.num_samples, 2)) + with pytest.raises(ValueError, match="Indexes must be convertable to a 2D"): + ts.genetic_relatedness_weighted(W, indexes=indexes) + + ############################################ # Fst ############################################ @@ -2340,7 +2422,6 @@ def single_site_Fst(ts, sample_sets, indexes): class TestFst(StatsTestCase, TwoWaySampleSetStatsMixin): - # Derived classes define this to get a specific stats mode. mode = None @@ -2529,7 +2610,6 @@ def Y2(ts, sample_sets, indexes=None, windows=None, mode="site", span_normalise= class TestY2(StatsTestCase, TwoWaySampleSetStatsMixin): - # Derived classes define this to get a specific stats mode. mode = None @@ -2702,7 +2782,6 @@ def Y3(ts, sample_sets, indexes=None, windows=None, mode="site", span_normalise= class TestY3(StatsTestCase, ThreeWaySampleSetStatsMixin): - # Derived classes define this to get a specific stats mode. mode = None @@ -2871,7 +2950,6 @@ def f2(ts, sample_sets, indexes=None, windows=None, mode="site", span_normalise= class Testf2(StatsTestCase, TwoWaySampleSetStatsMixin): - # Derived classes define this to get a specific stats mode. 
mode = None @@ -3057,7 +3135,6 @@ def f3(ts, sample_sets, indexes=None, windows=None, mode="site", span_normalise= class Testf3(StatsTestCase, ThreeWaySampleSetStatsMixin): - # Derived classes define this to get a specific stats mode. mode = None @@ -3248,7 +3325,6 @@ def f4(ts, sample_sets, indexes=None, windows=None, mode="site", span_normalise= class Testf4(StatsTestCase, FourWaySampleSetStatsMixin): - # Derived classes define this to get a specific stats mode. mode = None @@ -3512,7 +3588,6 @@ def update_result(window_index, u, right): last_update[u] = right for (t_left, t_right), edges_out, edges_in in ts.edge_diffs(): - for edge in edges_out: u = edge.child v = edge.parent @@ -3673,7 +3748,6 @@ def allele_frequency_spectrum( class TestAlleleFrequencySpectrum(StatsTestCase, SampleSetStatsMixin): - # Derived classes define this to get a specific stats mode. mode = None @@ -6003,7 +6077,6 @@ def f(x): branch_true_diversity_02, ], ): - self.assertAlmostEqual(diversity(ts, A, mode=mode)[0][0], truth) self.assertAlmostEqual(ts.sample_count_stat(A, f, 1, mode=mode)[0], truth) self.assertAlmostEqual(ts.diversity(A, mode="branch")[0], truth) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 236a45b8b7..10b3bd9782 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -7569,10 +7569,11 @@ def __k_way_weighted_stat( span_normalise=True, polarised=False, ): + W = np.asarray(W) if indexes is None: if W.shape[1] != k: raise ValueError( - "Must specify indexes if there are not exactly {} columsn " + "Must specify indexes if there are not exactly {} columns " "in W.".format(k) ) indexes = np.arange(k, dtype=np.int32) @@ -8016,7 +8017,7 @@ def genetic_relatedness_weighted( window (defaults to True). :return: A ndarray with shape equal to (num windows, num statistics). """ - if W.shape[0] != self.num_samples: + if len(W) != self.num_samples: raise ValueError( "First trait dimension must be equal to number of samples." 
) From f7ba5489ae9fa7bede54ad856f181c85f8759f6e Mon Sep 17 00:00:00 2001 From: astheeggeggs Date: Wed, 31 Aug 2022 16:57:54 +0100 Subject: [PATCH 63/84] Add backwards algorithm for haploid data, using lshmm for testing --- .github/workflows/tests.yml | 2 +- .../CI-tests-pip/requirements.txt | 2 +- python/tests/test_genotype_matching_fb.py | 1 - python/tests/test_haplotype_matching.py | 1392 +++++++---------- 4 files changed, 603 insertions(+), 794 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 473f04e893..58eb68a9d3 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -83,7 +83,7 @@ jobs: /usr/share/miniconda/envs/anaconda-client-env ~/osx-conda ~/.profile - key: ${{ runner.os }}-${{ matrix.python}}-conda-v11-${{ hashFiles('python/requirements/CI-tests-conda/requirements.txt') }}-${{ hashFiles('python/requirements/CI-tests-pip/requirements.txt') }} + key: ${{ runner.os }}-${{ matrix.python}}-conda-v12-${{ hashFiles('python/requirements/CI-tests-conda/requirements.txt') }}-${{ hashFiles('python/requirements/CI-tests-pip/requirements.txt') }} - name: Install Conda uses: conda-incubator/setup-miniconda@v2 diff --git a/python/requirements/CI-tests-pip/requirements.txt b/python/requirements/CI-tests-pip/requirements.txt index e9e16c64e5..9f3b31ef3d 100644 --- a/python/requirements/CI-tests-pip/requirements.txt +++ b/python/requirements/CI-tests-pip/requirements.txt @@ -1,4 +1,4 @@ -lshmm==0.0.4; python_version < '3.11' +lshmm==0.0.4 numpy==1.21.6; python_version < '3.11' # Held at 1.21.6 for Python 3.7 compatibility numpy==1.24.1; python_version > '3.10' pytest==7.1.3 diff --git a/python/tests/test_genotype_matching_fb.py b/python/tests/test_genotype_matching_fb.py index 248382e913..761eadf403 100644 --- a/python/tests/test_genotype_matching_fb.py +++ b/python/tests/test_genotype_matching_fb.py @@ -754,7 +754,6 @@ def compute_next_probability_dict( query_is_missing, ): mu = self.mu[site_id] - 
template_is_hom = np.logical_not(template_is_het) if query_is_missing: diff --git a/python/tests/test_haplotype_matching.py b/python/tests/test_haplotype_matching.py index 55f102939c..b09ebcc005 100644 --- a/python/tests/test_haplotype_matching.py +++ b/python/tests/test_haplotype_matching.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2019-2021 Tskit Developers +# Copyright (c) 2019-2023 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -20,332 +20,55 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. """ -Python implementation of the Li and Stephens algorithms. +Python implementation of the Li and Stephens forwards and backwards algorithms. """ import itertools -import unittest +import lshmm as ls import msprime import numpy as np -import pytest -import _tskit # TMP import tskit -from tests import tsutil +MISSING = -1 -def in_sorted(values, j): - # Take advantage of the fact that the numpy array is sorted. - ret = False - index = np.searchsorted(values, j) - if index < values.shape[0]: - ret = values[index] == j - return ret - -def ls_forward_matrix_naive(h, alleles, G, rho, mu): - """ - Simple matrix based method for LS forward algorithm using Python loops. - """ - assert rho[0] == 0 - m, n = G.shape - alleles = check_alleles(alleles, m) - F = np.zeros((m, n)) - S = np.zeros(m) - f = np.zeros(n) + 1 / n - - for el in range(0, m): - for j in range(n): - # NOTE Careful with the difference between this expression and - # the Viterbi algorithm below. This depends on the different - # normalisation approach. - p_t = f[j] * (1 - rho[el]) + rho[el] / n - p_e = mu[el] - if G[el, j] == h[el] or h[el] == tskit.MISSING_DATA: - p_e = 1 - (len(alleles[el]) - 1) * mu[el] - f[j] = p_t * p_e - S[el] = np.sum(f) - # TODO need to handle the 0 case. 
- assert S[el] > 0 - f /= S[el] - F[el] = f - return F, S - - -def ls_viterbi_naive(h, alleles, G, rho, mu): - """ - Simple matrix based method for LS Viterbi algorithm using Python loops. - """ - assert rho[0] == 0 - m, n = G.shape - alleles = check_alleles(alleles, m) - L = np.ones(n) - T = [set() for _ in range(m)] - T_dest = np.zeros(m, dtype=int) - - for el in range(m): - # The calculation below is undefined otherwise. - if len(alleles[el]) > 1: - assert mu[el] <= 1 / (len(alleles[el]) - 1) - L_next = np.zeros(n) - for j in range(n): - # NOTE Careful with the difference between this expression and - # the Forward algorithm above. This depends on the different - # normalisation approach. - p_no_recomb = L[j] * (1 - rho[el] + rho[el] / n) - p_recomb = rho[el] / n - if p_no_recomb > p_recomb: - p_t = p_no_recomb - else: - p_t = p_recomb - T[el].add(j) - p_e = mu[el] - if G[el, j] == h[el] or h[el] == tskit.MISSING_DATA: - p_e = 1 - (len(alleles[el]) - 1) * mu[el] - L_next[j] = p_t * p_e - L = L_next - j = np.argmax(L) - T_dest[el] = j - if L[j] == 0: - assert mu[el] == 0 - raise ValueError( - "Trying to match non-existent allele with zero mutation rate" - ) - L /= L[j] - - P = np.zeros(m, dtype=int) - P[m - 1] = T_dest[m - 1] - for el in range(m - 1, 0, -1): - j = P[el] - if j in T[el]: - j = T_dest[el - 1] - P[el - 1] = j - return P - - -def ls_viterbi_vectorised(h, alleles, G, rho, mu): - # We must have a non-zero mutation rate, or we'll end up with - # division by zero problems. 
- # assert np.all(mu > 0) - - m, n = G.shape - alleles = check_alleles(alleles, m) - V = np.ones(n) - T = [None for _ in range(m)] - max_index = np.zeros(m, dtype=int) - - for site in range(m): - # Transition - p_neq = rho[site] / n - p_t = (1 - rho[site] + rho[site] / n) * V - recombinations = np.where(p_neq > p_t)[0] - p_t[recombinations] = p_neq - T[site] = recombinations - # Emission - p_e = np.zeros(n) + mu[site] - index = G[site] == h[site] - if h[site] == tskit.MISSING_DATA: - # Missing data is considered equal to everything - index[:] = True - p_e[index] = 1 - (len(alleles[site]) - 1) * mu[site] - V = p_t * p_e - # Normalise - max_index[site] = np.argmax(V) - # print(site, ":", V) - if V[max_index[site]] == 0: - assert mu[site] == 0 - raise ValueError( - "Trying to match non-existent allele with zero mutation rate" - ) - V /= V[max_index[site]] - - # Traceback - P = np.zeros(m, dtype=int) - site = m - 1 - P[site] = max_index[site] - while site > 0: - j = P[site] - if in_sorted(T[site], j): - j = max_index[site - 1] - P[site - 1] = j - site -= 1 - return P - - -def check_alleles(alleles, num_sites): +def check_alleles(alleles, m): """ Checks the specified allele list and returns a list of lists of alleles of length num_sites. - If alleles is a 1D list of strings, assume that this list is used for each site and return num_sites copies of this list. - Otherwise, raise a ValueError if alleles is not a list of length num_sites. """ if isinstance(alleles[0], str): - return [alleles for _ in range(num_sites)] - if len(alleles) != num_sites: + return [alleles for _ in range(m)], np.int8([len(alleles) for _ in range(m)]) + if len(alleles) != m: raise ValueError("Malformed alleles list") - return alleles + n_alleles = np.int8([(len(alleles_site)) for alleles_site in alleles]) + return alleles, n_alleles -def ls_forward_matrix(h, alleles, G, rho, mu): +def mirror_coordinates(ts): """ - Simple matrix based method for LS forward algorithm using numpy vectorisation. 
+ Returns a copy of the specified tree sequence in which all + coordinates x are transformed into L - x. """ - assert rho[0] == 0 - m, n = G.shape - alleles = check_alleles(alleles, m) - F = np.zeros((m, n)) - S = np.zeros(m) - f = np.zeros(n) + 1 / n - p_e = np.zeros(n) - - for el in range(0, m): - p_t = f * (1 - rho[el]) + rho[el] / n - eq = G[el] == h[el] - if h[el] == tskit.MISSING_DATA: - # Missing data is equal to everything - eq[:] = True - p_e[:] = mu[el] - p_e[eq] = 1 - (len(alleles[el]) - 1) * mu[el] - f = p_t * p_e - S[el] = np.sum(f) - # TODO need to handle the 0 case. - assert S[el] > 0 - f /= S[el] - F[el] = f - return F, S - - -def forward_matrix_log_proba(F, S): - """ - Given the specified forward matrix and scaling factor array, return the - overall log probability of the input haplotype. - """ - return np.sum(np.log(S)) - np.log(np.sum(F[-1])) - - -def ls_forward_matrix_unscaled(h, alleles, G, rho, mu): - """ - Simple matrix based method for LS forward algorithm. - """ - assert rho[0] == 0 - m, n = G.shape - alleles = check_alleles(alleles, m) - F = np.zeros((m, n)) - f = np.zeros(n) + 1 / n - - for el in range(0, m): - s = np.sum(f) - for j in range(n): - p_t = f[j] * (1 - rho[el]) + s * rho[el] / n - p_e = mu[el] - if G[el, j] == h[el] or h[el] == tskit.MISSING_DATA: - p_e = 1 - (len(alleles[el]) - 1) * mu[el] - f[j] = p_t * p_e - F[el] = f - return F - - -# TODO change this to use the log_proba function below. -def ls_path_probability(h, path, G, rho, mu): - """ - Returns the probability of the specified path through the genotypes for the - specified haplotype. - """ - # Assuming num_alleles = 2 - assert rho[0] == 0 - m, n = G.shape - # TODO It's not entirely clear why we're starting with a proba of 1 / n for the - # model. This was done because it made it easier to compare with an existing - # HMM implementation. Need to figure this one out when writing up. 
- proba = 1 / n - for site in range(0, m): - pe = mu[site] - if h[site] == G[site, path[site]] or h[site] == tskit.MISSING_DATA: - pe = 1 - mu[site] - pt = rho[site] / n - if site == 0 or path[site] == path[site - 1]: - pt = 1 - rho[site] + rho[site] / n - proba *= pt * pe - return proba - - -def ls_path_log_probability(h, path, alleles, G, rho, mu): - """ - Returns the log probability of the specified path through the genotypes for the - specified haplotype. - """ - assert rho[0] == 0 - m, n = G.shape - alleles = check_alleles(alleles, m) - # TODO It's not entirely clear why we're starting with a proba of 1 / n for the - # model. This was done because it made it easier to compare with an existing - # HMM implementation. Need to figure this one out when writing up. - log_proba = np.log(1 / n) - for site in range(0, m): - if len(alleles[site]) > 1: - assert mu[site] <= 1 / (len(alleles[site]) - 1) - pe = mu[site] - if h[site] == G[site, path[site]] or h[site] == tskit.MISSING_DATA: - pe = 1 - (len(alleles[site]) - 1) * mu[site] - assert 0 <= pe <= 1 - pt = rho[site] / n - if site == 0 or path[site] == path[site - 1]: - pt = 1 - rho[site] + rho[site] / n - assert 0 <= pt <= 1 - log_proba += np.log(pt) + np.log(pe) - return log_proba - - -def ls_forward_tree(h, alleles, ts, rho, mu, precision=30, use_lib=True): - """ - Forward matrix computation based on a tree sequence. - """ - if use_lib: - acgt_alleles = tuple(alleles) == tskit.ALLELES_ACGT - ls_hmm = _tskit.LsHmm( - ts.ll_tree_sequence, - recombination_rate=rho, - mutation_rate=mu, - precision=precision, - acgt_alleles=acgt_alleles, - ) - cm = _tskit.CompressedMatrix(ts.ll_tree_sequence) - ls_hmm.forward_matrix(h, cm) - return cm - else: - fa = ForwardAlgorithm(ts, rho, mu, alleles, precision=precision) - return fa.run(h) - - -def ls_viterbi_tree(h, alleles, ts, rho, mu, precision=30, use_lib=True): - """ - Viterbi path computation based on a tree sequence. 
- """ - if use_lib: - acgt_alleles = tuple(alleles) == tskit.ALLELES_ACGT - ls_hmm = _tskit.LsHmm( - ts.ll_tree_sequence, - recombination_rate=rho, - mutation_rate=mu, - precision=precision, - acgt_alleles=acgt_alleles, - ) - vm = _tskit.ViterbiMatrix(ts.ll_tree_sequence) - ls_hmm.viterbi_matrix(h, vm) - return vm - else: - va = ViterbiAlgorithm(ts, rho, mu, alleles, precision=precision) - return va.run(h) + L = ts.sequence_length + tables = ts.dump_tables() + left = tables.edges.left + right = tables.edges.right + tables.edges.left = L - right + tables.edges.right = L - left + tables.sites.position = L - tables.sites.position # + 1 + # TODO migrations. + tables.sort() + return tables.tree_sequence() class ValueTransition: - """ - Simple struct holding value transition values. - """ + """Simple struct holding value transition values.""" def __init__(self, tree_node=-1, value=-1, value_index=-1): self.tree_node = tree_node @@ -353,7 +76,11 @@ def __init__(self, tree_node=-1, value=-1, value_index=-1): self.value_index = value_index def copy(self): - return ValueTransition(self.tree_node, self.value, self.value_index) + return ValueTransition( + self.tree_node, + self.value, + self.value_index, + ) def __repr__(self): return repr(self.__dict__) @@ -367,11 +94,12 @@ class LsHmmAlgorithm: Abstract superclass of Li and Stephens HMM algorithm. """ - def __init__(self, ts, rho, mu, alleles, precision=10): + def __init__( + self, ts, rho, mu, alleles, n_alleles, precision=10, scale_mutation=False + ): self.ts = ts self.mu = mu self.rho = rho - self.alleles = check_alleles(alleles, ts.num_sites) self.precision = precision # The array of ValueTransitions. 
self.T = [] @@ -386,6 +114,10 @@ def __init__(self, ts, rho, mu, alleles, precision=10): self.parent = np.zeros(self.ts.num_nodes, dtype=int) - 1 self.tree = tskit.Tree(self.ts) self.output = None + # Vector of the number of alleles at each site + self.n_alleles = n_alleles + self.alleles = alleles + self.scale_mutation_based_on_n_alleles = scale_mutation def check_integrity(self): M = [st.tree_node for st in self.T if st.tree_node != -1] @@ -422,10 +154,6 @@ def compute(u, parent_state): for j in range(num_values): value_count[j] += child[j] max_value_count = np.max(value_count) - # NOTE: we need to set the set to zero here because we actually - # visit some nodes more than once during the postorder traversal. - # This would seem to be wasteful, so we should revisit this when - # cleaning up the algorithm logic. optimal_set[u, :] = 0 optimal_set[u, value_count == max_value_count] = 1 @@ -566,9 +294,9 @@ def update_probabilities(self, site, haplotype_state): T = self.T alleles = self.alleles[site.id] allelic_state = self.allelic_state - # Set the allelic_state for this site. allelic_state[tree.root] = alleles.index(site.ancestral_state) + for mutation in site.mutations: u = mutation.node allelic_state[u] = alleles.index(mutation.derived_state) @@ -590,8 +318,7 @@ def update_probabilities(self, site, haplotype_state): v = tree.parent(v) assert v != -1 match = ( - haplotype_state == tskit.MISSING_DATA - or haplotype_state == allelic_state[v] + haplotype_state == MISSING or haplotype_state == allelic_state[v] ) st.value = self.compute_next_probability(site.id, st.value, match, u) @@ -600,31 +327,41 @@ def update_probabilities(self, site, haplotype_state): for mutation in site.mutations: allelic_state[mutation.node] = -1 - def process_site(self, site, haplotype_state): - # print(site.id, "num_transitions=", len(self.T)) - self.update_probabilities(site, haplotype_state) - # FIXME We don't want to call compress here. 
- # What we really want to do is just call compress after - # the values have been normalised and rounded. However, we can't - # compute the normalisation factor in the forwards algorithm without - # the N counts (number of samples directly below each value transition - # in T), and these are currently computed during compress. So to make - # things work for now we call compress before and put up with having - # a slightly less than optimally compressed output matrix. It might - # end up that this makes no difference and compressing the - # pre-rounded values is basically the same thing. - self.compress() - s = self.compute_normalisation_factor() - for st in self.T: - if st.tree_node != tskit.NULL: - st.value /= s - st.value = round(st.value, self.precision) - # *This* is where we want to compress (and can, for viterbi). - # self.compress() - self.output.store_site(site.id, s, [(st.tree_node, st.value) for st in self.T]) - - def run(self, h): + def process_site(self, site, haplotype_state, forwards=True): + if forwards: + # Forwards algorithm, or forwards pass in Viterbi + self.update_probabilities(site, haplotype_state) + self.compress() + s = self.compute_normalisation_factor() + for st in self.T: + if st.tree_node != tskit.NULL: + st.value /= s + st.value = round(st.value, self.precision) + self.output.store_site( + site.id, s, [(st.tree_node, st.value) for st in self.T] + ) + else: + # Backwards algorithm + self.output.store_site( + site.id, + self.output.normalisation_factor[site.id], + [(st.tree_node, st.value) for st in self.T], + ) + self.update_probabilities(site, haplotype_state) + self.compress() + b_last_sum = self.compute_normalisation_factor() + s = self.output.normalisation_factor[site.id] + for st in self.T: + if st.tree_node != tskit.NULL: + st.value = ( + self.rho[site.id] / self.ts.num_samples + ) * b_last_sum + (1 - self.rho[site.id]) * st.value + st.value /= s + st.value = round(st.value, self.precision) + + def run_forward(self, h): n = 
self.ts.num_samples + self.tree.clear() for u in self.ts.samples(): self.T_index[u] = len(self.T) self.T.append(ValueTransition(tree_node=u, value=1 / n)) @@ -634,6 +371,17 @@ def run(self, h): self.process_site(site, h[site.id]) return self.output + def run_backward(self, h): + self.tree.clear() + for u in self.ts.samples(): + self.T_index[u] = len(self.T) + self.T.append(ValueTransition(tree_node=u, value=1)) + while self.tree.next(): + self.update_tree() + for site in self.tree.sites(): + self.process_site(site, h[site.id], forwards=False) + return self.output + def compute_normalisation_factor(self): raise NotImplementedError() @@ -650,12 +398,16 @@ class CompressedMatrix: values are on the path). """ - def __init__(self, ts): + def __init__(self, ts, normalisation_factor=None): self.ts = ts self.num_sites = ts.num_sites self.num_samples = ts.num_samples self.value_transitions = [None for _ in range(self.num_sites)] - self.normalisation_factor = np.zeros(self.num_sites) + if normalisation_factor is None: + self.normalisation_factor = np.zeros(self.num_sites) + else: + self.normalisation_factor = normalisation_factor + assert len(self.normalisation_factor) == self.num_sites def store_site(self, site, normalisation_factor, value_transitions): self.normalisation_factor[site] = normalisation_factor @@ -688,39 +440,11 @@ def decode(self): class ForwardMatrix(CompressedMatrix): - """ - Class representing a compressed forward matrix. - """ - - -class ForwardAlgorithm(LsHmmAlgorithm): - """ - Runs the Li and Stephens forward algorithm. 
- """ - - def __init__(self, ts, rho, mu, alleles, precision=10): - super().__init__(ts, rho, mu, alleles, precision) - self.output = ForwardMatrix(ts) - - def compute_normalisation_factor(self): - s = 0 - for j, st in enumerate(self.T): - assert st.tree_node != tskit.NULL - assert self.N[j] > 0 - s += self.N[j] * st.value - return s + """Class representing a compressed forward matrix.""" - def compute_next_probability(self, site_id, p_last, is_match, node): - rho = self.rho[site_id] - mu = self.mu[site_id] - alleles = self.alleles[site_id] - n = self.ts.num_samples - p_t = p_last * (1 - rho) + rho / n - p_e = mu - if is_match: - p_e = 1 - (len(alleles) - 1) * mu - return p_t * p_e +class BackwardMatrix(CompressedMatrix): + """Class representing a compressed backward matrix.""" class ViterbiMatrix(CompressedMatrix): @@ -730,6 +454,8 @@ class ViterbiMatrix(CompressedMatrix): def __init__(self, ts): super().__init__(ts) + # Tuple containing the site, the node in the tree, and whether + # recombination is required self.recombination_required = [(-1, 0, False)] def add_recombination_required(self, site, node, required): @@ -801,13 +527,144 @@ def traceback(self): return match +class ForwardAlgorithm(LsHmmAlgorithm): + """Runs the Li and Stephens forward algorithm.""" + + def __init__( + self, ts, rho, mu, alleles, n_alleles, scale_mutation=False, precision=10 + ): + super().__init__( + ts, + rho, + mu, + alleles, + n_alleles, + precision=precision, + scale_mutation=scale_mutation, + ) + self.output = ForwardMatrix(ts) + + def compute_normalisation_factor(self): + s = 0 + for j, st in enumerate(self.T): + assert st.tree_node != tskit.NULL + assert self.N[j] > 0 + s += self.N[j] * st.value + return s + + def compute_next_probability( + self, site_id, p_last, is_match, node + ): # Note node only used in Viterbi + rho = self.rho[site_id] + mu = self.mu[site_id] + n = self.ts.num_samples + n_alleles = self.n_alleles[site_id] + + if self.scale_mutation_based_on_n_alleles: + 
if is_match: + # Scale mutation based on the number of alleles + # - so the mutation rate is the mutation rate to one of the + # alleles. The overall mutation rate is then + # (n_alleles - 1) * mutation_rate. + p_e = 1 - (n_alleles - 1) * mu + else: + p_e = mu - mu * (n_alleles == 1) + # Added boolean in case we're at an invariant site + else: + # No scaling based on the number of alleles + # - so the mutation rate is the mutation rate to anything. + # This means that we must rescale the mutation rate to a different + # allele, by the number of alleles. + if n_alleles == 1: # In case we're at an invariant site + if is_match: + p_e = 1 + else: + p_e = 0 + else: + if is_match: + p_e = 1 - mu + else: + p_e = mu / (n_alleles - 1) + + p_t = p_last * (1 - rho) + rho / n + return p_t * p_e + + +class BackwardAlgorithm(LsHmmAlgorithm): + """Runs the Li and Stephens backward algorithm.""" + + def __init__( + self, + ts, + rho, + mu, + alleles, + n_alleles, + normalisation_factor, + scale_mutation=False, + precision=10, + ): + super().__init__( + ts, + rho, + mu, + alleles, + n_alleles, + precision=precision, + scale_mutation=scale_mutation, + ) + self.output = BackwardMatrix(ts, normalisation_factor) + + def compute_normalisation_factor(self): + s = 0 + for j, st in enumerate(self.T): + assert st.tree_node != tskit.NULL + assert self.N[j] > 0 + s += self.N[j] * st.value + return s + + def compute_next_probability( + self, site_id, p_next, is_match, node + ): # Note node only used in Viterbi + mu = self.mu[site_id] + n_alleles = self.n_alleles[site_id] + + if self.scale_mutation_based_on_n_alleles: + if is_match: + p_e = 1 - (n_alleles - 1) * mu + else: + p_e = mu - mu * (n_alleles == 1) + else: + if n_alleles == 1: + if is_match: + p_e = 1 + else: + p_e = 0 + else: + if is_match: + p_e = 1 - mu + else: + p_e = mu / (n_alleles - 1) + return p_next * p_e + + class ViterbiAlgorithm(LsHmmAlgorithm): """ Runs the Li and Stephens Viterbi algorithm. 
""" - def __init__(self, ts, rho, mu, alleles, precision=10): - super().__init__(ts, rho, mu, alleles, precision) + def __init__( + self, ts, rho, mu, alleles, n_alleles, scale_mutation=False, precision=10 + ): + super().__init__( + ts, + rho, + mu, + alleles, + n_alleles, + precision=precision, + scale_mutation=scale_mutation, + ) self.output = ViterbiMatrix(ts) def compute_normalisation_factor(self): @@ -825,8 +682,8 @@ def compute_normalisation_factor(self): def compute_next_probability(self, site_id, p_last, is_match, node): rho = self.rho[site_id] mu = self.mu[site_id] - alleles = self.alleles[site_id] n = self.ts.num_samples + n_alleles = self.n_alleles[site_id] p_no_recomb = p_last * (1 - rho + rho / n) p_recomb = rho / n @@ -837,474 +694,427 @@ def compute_next_probability(self, site_id, p_last, is_match, node): p_t = p_recomb recombination_required = True self.output.add_recombination_required(site_id, node, recombination_required) - p_e = mu - if is_match: - p_e = 1 - (len(alleles) - 1) * mu - return p_t * p_e + if self.scale_mutation_based_on_n_alleles: + if is_match: + # Scale mutation based on the number of alleles + # - so the mutation rate is the mutation rate to one of the + # alleles. The overall mutation rate is then + # (n_alleles - 1) * mutation_rate. + p_e = 1 - (n_alleles - 1) * mu + else: + p_e = mu - mu * (n_alleles == 1) + # Added boolean in case we're at an invariant site + else: + # No scaling based on the number of alleles + # - so the mutation rate is the mutation rate to anything. + # This means that we must rescale the mutation rate to a different + # allele, by the number of alleles. 
+ if n_alleles == 1: # In case we're at an invariant site + if is_match: + p_e = 1 + else: + p_e = 0 + else: + if is_match: + p_e = 1 - mu + else: + p_e = mu / (n_alleles - 1) -################################################################ -# Tests -################################################################ + return p_t * p_e -class LiStephensBase: +def ls_forward_tree( + h, ts, rho, mu, precision=30, alleles=None, scale_mutation_based_on_n_alleles=False +): + if alleles is None: + n_alleles = np.int8( + [ + len(np.unique(np.append(ts.genotype_matrix()[j, :], h[j]))) + for j in range(ts.num_sites) + ] + ) + alleles = tskit.ALLELES_ACGT + if len(set(alleles).intersection(next(ts.variants()).alleles)) == 0: + alleles = tskit.ALLELES_01 + if len(set(alleles).intersection(next(ts.variants()).alleles)) == 0: + raise ValueError( + """Alleles list could not be identified. + Please pass a list of lists of alleles of length m, + or a list of alleles (e.g. tskit.ALLELES_ACGT)""" + ) + alleles = [alleles for _ in range(ts.num_sites)] + else: + alleles, n_alleles = check_alleles(alleles, ts.num_sites) + + """Forward matrix computation based on a tree sequence.""" + fa = ForwardAlgorithm( + ts, + rho, + mu, + alleles, + n_alleles, + precision=precision, + scale_mutation=scale_mutation_based_on_n_alleles, + ) + return fa.run_forward(h) + + +def ls_backward_tree( + h, ts_mirror, rho, mu, normalisation_factor, precision=30, alleles=None +): + if alleles is None: + n_alleles = np.int8( + [ + len(np.unique(np.append(ts_mirror.genotype_matrix()[j, :], h[j]))) + for j in range(ts_mirror.num_sites) + ] + ) + alleles = tskit.ALLELES_ACGT + if len(set(alleles).intersection(next(ts_mirror.variants()).alleles)) == 0: + alleles = tskit.ALLELES_01 + if len(set(alleles).intersection(next(ts_mirror.variants()).alleles)) == 0: + raise ValueError( + """Alleles list could not be identified. + Please pass a list of lists of alleles of length m, + or a list of alleles (e.g. 
tskit.ALLELES_ACGT)""" + ) + alleles = [alleles for _ in range(ts_mirror.num_sites)] + else: + alleles, n_alleles = check_alleles(alleles, ts_mirror.num_sites) + + """Backward matrix computation based on a tree sequence.""" + ba = BackwardAlgorithm( + ts_mirror, + rho, + mu, + alleles, + n_alleles, + normalisation_factor, + precision=precision, + ) + return ba.run_backward(h) + + +def ls_viterbi_tree( + h, ts, rho, mu, precision=30, alleles=None, scale_mutation_based_on_n_alleles=False +): + if alleles is None: + n_alleles = np.int8( + [ + len(np.unique(np.append(ts.genotype_matrix()[j, :], h[j]))) + for j in range(ts.num_sites) + ] + ) + alleles = tskit.ALLELES_ACGT + if len(set(alleles).intersection(next(ts.variants()).alleles)) == 0: + alleles = tskit.ALLELES_01 + if len(set(alleles).intersection(next(ts.variants()).alleles)) == 0: + raise ValueError( + """Alleles list could not be identified. + Please pass a list of lists of alleles of length m, + or a list of alleles (e.g. tskit.ALLELES_ACGT)""" + ) + alleles = [alleles for _ in range(ts.num_sites)] + else: + alleles, n_alleles = check_alleles(alleles, ts.num_sites) """ - Superclass of Li and Stephens tests. + Viterbi path computation based on a tree sequence. """ - - def assertCompressedMatricesEqual(self, cm1, cm2): - """ - Checks that the specified compressed matrices contain the same data. - """ - A1 = cm1.decode() - A2 = cm2.decode() - assert np.allclose(A1, A2) - assert A1.shape == A2.shape - assert cm1.num_sites == cm2.num_sites - nf1 = cm1.normalisation_factor - nf2 = cm1.normalisation_factor - assert np.allclose(nf1, nf2) - assert nf1.shape == nf2.shape - # It seems that we can't rely on the number of transitions in the two - # implementations being equal, which seems odd given that we should - # be doing things identically. Still, once the decoded matrices are the - # same then it seems highly likely to be correct. 
- - # if not np.array_equal(cm1.num_transitions, cm2.num_transitions): - # print() - # print(cm1.num_transitions) - # print(cm2.num_transitions) - # self.assertTrue(np.array_equal(cm1.num_transitions, cm2.num_transitions)) - # for j in range(cm1.num_sites): - # s1 = dict(cm1.get_site(j)) - # s2 = dict(cm2.get_site(j)) - # self.assertEqual(set(s1.keys()), set(s2.keys())) - # for key in s1.keys(): - # self.assertAlmostEqual(s1[key], s2[key]) - - def example_haplotypes(self, ts, alleles, num_random=10, seed=2): - rng = np.random.RandomState(seed) - H = ts.genotype_matrix(alleles=alleles).T - haplotypes = [H[0], H[-1]] - for _ in range(num_random): - # Choose a random path through H - p = rng.randint(0, ts.num_samples, ts.num_sites) - h = H[p, np.arange(ts.num_sites)] - haplotypes.append(h) - h = H[0].copy() - h[-1] = tskit.MISSING_DATA - haplotypes.append(h) - h = H[0].copy() - h[ts.num_sites // 2] = tskit.MISSING_DATA - haplotypes.append(h) - # All missing is OK tool - h = H[0].copy() - h[:] = tskit.MISSING_DATA - haplotypes.append(h) - return haplotypes - - def example_parameters(self, ts, alleles, seed=1): - """ - Returns an iterator over combinations of haplotype, recombination and mutation - rates. - """ - rng = np.random.RandomState(seed) - haplotypes = self.example_haplotypes(ts, alleles, seed=seed) - - # This is the exact matching limit. 
- rho = np.zeros(ts.num_sites) + 0.01 - mu = np.zeros(ts.num_sites) - rho[0] = 0 - for h in haplotypes: - yield h, rho, mu + va = ViterbiAlgorithm( + ts, + rho, + mu, + alleles, + n_alleles, + precision=precision, + scale_mutation=scale_mutation_based_on_n_alleles, + ) + return va.run_forward(h) + + +class LSBase: + """Superclass of Li and Stephens tests.""" + + def example_haplotypes(self, ts): + + H = ts.genotype_matrix() + s = H[:, 0].reshape(1, H.shape[0]) + H = H[:, 1:] + + haplotypes = [ + s, + H[:, -1].reshape(1, H.shape[0]), + ] + s_tmp = s.copy() + s_tmp[0, -1] = MISSING + haplotypes.append(s_tmp) + s_tmp = s.copy() + s_tmp[0, ts.num_sites // 2] = MISSING + haplotypes.append(s_tmp) + s_tmp = s.copy() + s_tmp[0, :] = MISSING + haplotypes.append(s_tmp) + + return H, haplotypes + + def example_parameters_haplotypes(self, ts, seed=42): + """Returns an iterator over combinations of haplotype, + recombination and mutation rates.""" + np.random.seed(seed) + H, haplotypes = self.example_haplotypes(ts) + n = H.shape[1] + m = ts.get_num_sites() # Here we have equal mutation and recombination - rho = np.zeros(ts.num_sites) + 0.01 - mu = np.zeros(ts.num_sites) + 0.01 - rho[0] = 0 - for h in haplotypes: - yield h, rho, mu + r = np.zeros(m) + 0.01 + mu = np.zeros(m) + 0.01 + r[0] = 0 + + for s in haplotypes: + yield n, H, s, r, mu # Mixture of random and extremes - rhos = [ - np.zeros(ts.num_sites) + 0.999, - np.zeros(ts.num_sites) + 1e-6, - rng.uniform(0, 1, ts.num_sites), - ] - # mu can't be more than 1 / 3 if we have 4 alleles - mus = [ - np.zeros(ts.num_sites) + 0.33, - np.zeros(ts.num_sites) + 1e-6, - rng.uniform(0, 0.33, ts.num_sites), - ] - for h, rho, mu in itertools.product(haplotypes, rhos, mus): - rho[0] = 0 - yield h, rho, mu + rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)] + mus = [np.zeros(m) + 0.33, np.zeros(m) + 1e-6, np.random.rand(m) * 0.33] + + for s, r, mu in itertools.product(haplotypes, rs, mus): + r[0] = 0 + yield n, H, s, r, mu 
def assertAllClose(self, A, B): - assert np.allclose(A, B) + """Assert that all entries of two matrices are 'close'""" + assert np.allclose(A, B, rtol=1e-5, atol=1e-8) + + # Define a bunch of very small tree-sequences for testing a collection + # of parameters on + def test_simple_n_10_no_recombination(self): + ts = msprime.simulate( + 10, recombination_rate=0, mutation_rate=0.5, random_seed=42 + ) + assert ts.num_sites > 3 + self.verify(ts) - def test_simple_n_4_no_recombination(self): - ts = msprime.simulate(4, recombination_rate=0, mutation_rate=0.5, random_seed=1) + def test_simple_n_10_no_recombination_high_mut(self): + ts = msprime.simulate(10, recombination_rate=0, mutation_rate=3, random_seed=42) assert ts.num_sites > 3 self.verify(ts) - def test_simple_n_3(self): - ts = msprime.simulate(3, recombination_rate=2, mutation_rate=7, random_seed=2) - assert ts.num_sites > 5 + def test_simple_n_10_no_recombination_higher_mut(self): + ts = msprime.simulate(20, recombination_rate=0, mutation_rate=3, random_seed=42) + assert ts.num_sites > 3 self.verify(ts) - def test_simple_n_7(self): - ts = msprime.simulate(7, recombination_rate=2, mutation_rate=5, random_seed=2) + def test_simple_n_6(self): + ts = msprime.simulate(6, recombination_rate=2, mutation_rate=7, random_seed=42) assert ts.num_sites > 5 self.verify(ts) - def test_simple_n_8_high_recombination(self): - ts = msprime.simulate(8, recombination_rate=20, mutation_rate=5, random_seed=2) - assert ts.num_trees > 15 + def test_simple_n_8(self): + ts = msprime.simulate(8, recombination_rate=2, mutation_rate=5, random_seed=42) assert ts.num_sites > 5 self.verify(ts) - def test_simple_n_15(self): - ts = msprime.simulate(15, recombination_rate=2, mutation_rate=5, random_seed=2) + def test_simple_n_8_high_recombination(self): + ts = msprime.simulate(8, recombination_rate=20, mutation_rate=5, random_seed=42) + assert ts.num_trees > 15 assert ts.num_sites > 5 self.verify(ts) - def test_jukes_cantor_n_3(self): - ts = 
msprime.simulate(3, mutation_rate=2, random_seed=2) - ts = tsutil.jukes_cantor(ts, num_sites=10, mu=10, seed=4) - self.verify(ts, tskit.ALLELES_ACGT) - - def test_jukes_cantor_n_8_high_recombination(self): - ts = msprime.simulate(8, recombination_rate=20, random_seed=2) - ts = tsutil.jukes_cantor(ts, num_sites=20, mu=5, seed=4) - self.verify(ts, tskit.ALLELES_ACGT) - - def test_jukes_cantor_n_15(self): - ts = msprime.simulate(15, mutation_rate=2, random_seed=2) - ts = tsutil.jukes_cantor(ts, num_sites=10, mu=0.1, seed=10) - self.verify(ts, tskit.ALLELES_ACGT) - - def test_jukes_cantor_balanced_ternary(self): - ts = tskit.Tree.generate_balanced(27, arity=3).tree_sequence - ts = tsutil.jukes_cantor(ts, num_sites=10, mu=0.1, seed=10) - self.verify(ts, tskit.ALLELES_ACGT) - - @pytest.mark.skip(reason="Not supporting internal samples yet") - def test_ancestors_n_3(self): - ts = msprime.simulate(3, recombination_rate=2, mutation_rate=7, random_seed=2) + def test_simple_n_16(self): + ts = msprime.simulate(16, recombination_rate=2, mutation_rate=5, random_seed=42) assert ts.num_sites > 5 - tables = ts.dump_tables() - print(tables.nodes) - tables.nodes.flags = np.ones_like(tables.nodes.flags) - print(tables.nodes) - ts = tables.tree_sequence() self.verify(ts) + # # Define a bunch of very small tree-sequences for testing a collection + # # of parameters on + # def test_simple_n_10_no_recombination_blah(self): + # ts = msprime.sim_ancestry( + # samples=10, + # recombination_rate=0, + # random_seed=42, + # sequence_length=10, + # population_size=10000, + # ) + # ts = msprime.sim_mutations(ts, rate=1e-5, random_seed=42) + # assert ts.num_sites > 3 + # self.verify(ts) + + # def test_simple_n_6_blah(self): + # ts = msprime.sim_ancestry( + # samples=6, + # recombination_rate=1e-4, + # random_seed=42, + # sequence_length=40, + # population_size=10000, + # ) + # ts = msprime.sim_mutations(ts, rate=1e-3, random_seed=42) + # assert ts.num_sites > 5 + # self.verify(ts) + + # def 
test_simple_n_8_blah(self): + # ts = msprime.sim_ancestry( + # samples=8, + # recombination_rate=1e-4, + # random_seed=42, + # sequence_length=20, + # population_size=10000, + # ) + # ts = msprime.sim_mutations(ts, rate=1e-4, random_seed=42) + # assert ts.num_sites > 5 + # assert ts.num_trees > 15 + # self.verify(ts) + + # def test_simple_n_16_blah(self): + # ts = msprime.sim_ancestry( + # samples=16, + # recombination_rate=1e-2, + # random_seed=42, + # sequence_length=20, + # population_size=10000, + # ) + # ts = msprime.sim_mutations(ts, rate=1e-4, random_seed=42) + # assert ts.num_sites > 5 + # self.verify(ts) + + def verify(self, ts): + raise NotImplementedError() -@pytest.mark.slow -class ForwardAlgorithmBase(LiStephensBase): - """ - Base for forward algorithm tests. - """ +class FBAlgorithmBase(LSBase): + """Base for forwards backwards algorithm tests.""" -class TestNumpyMatrixMethod(ForwardAlgorithmBase): - """ - Tests that we compute the same values from the numpy matrix method as - the naive algorithm. - """ - def verify(self, ts, alleles=tskit.ALLELES_01): - G = ts.genotype_matrix(alleles=alleles) - for h, rho, mu in self.example_parameters(ts, alleles): - F1, S1 = ls_forward_matrix(h, alleles, G, rho, mu) - F2, S2 = ls_forward_matrix_naive(h, alleles, G, rho, mu) - self.assertAllClose(F1, F2) - self.assertAllClose(S1, S2) +class VitAlgorithmBase(LSBase): + """Base for viterbi algoritm tests.""" -class ViterbiAlgorithmBase(LiStephensBase): - """ - Base for viterbi algoritm tests. - """ +class TestMirroringHap(FBAlgorithmBase): + """Tests that mirroring the tree sequence and running forwards and backwards + algorithms gives the same log-likelihood of observing the data.""" + def verify(self, ts): + for n, H, s, r, mu in self.example_parameters_haplotypes(ts): + # Note, need to remove the first sample from the ts, and ensure that + # invariant sites aren't removed. 
+ ts_check = ts.simplify(range(1, n + 1), filter_sites=False) + cm = ls_forward_tree(s[0, :], ts_check, r, mu) + ll_tree = np.sum(np.log10(cm.normalisation_factor)) -class TestExactMatchViterbi(ViterbiAlgorithmBase): - def verify(self, ts, alleles=tskit.ALLELES_01): - G = ts.genotype_matrix(alleles=alleles) - H = G.T - # print(H) - rho = np.zeros(ts.num_sites) + 0.1 - mu = np.zeros(ts.num_sites) - rho[0] = 0 - for h in H: - p1 = ls_viterbi_naive(h, alleles, G, rho, mu) - p2 = ls_viterbi_vectorised(h, alleles, G, rho, mu) - cm1 = ls_viterbi_tree(h, alleles, ts, rho, mu, use_lib=True) - p3 = cm1.traceback() - cm2 = ls_viterbi_tree(h, alleles, ts, rho, mu, use_lib=False) - p4 = cm1.traceback() - self.assertCompressedMatricesEqual(cm1, cm2) - - assert len(np.unique(p1)) == 1 - assert len(np.unique(p2)) == 1 - assert len(np.unique(p3)) == 1 - assert len(np.unique(p4)) == 1 - m1 = H[p1, np.arange(H.shape[1])] - assert np.array_equal(m1, h) - m2 = H[p2, np.arange(H.shape[1])] - assert np.array_equal(m2, h) - m3 = H[p3, np.arange(H.shape[1])] - assert np.array_equal(m3, h) - m4 = H[p3, np.arange(H.shape[1])] - assert np.array_equal(m4, h) - - -@pytest.mark.slow -class TestGeneralViterbi(ViterbiAlgorithmBase, unittest.TestCase): - def verify(self, ts, alleles=tskit.ALLELES_01): - # np.set_printoptions(linewidth=20000) - # np.set_printoptions(threshold=20000000) - G = ts.genotype_matrix(alleles=alleles) - # m, n = G.shape - for h, rho, mu in self.example_parameters(ts, alleles): - # print("h = ", h) - # print("rho=", rho) - # print("mu = ", mu) - p1 = ls_viterbi_vectorised(h, alleles, G, rho, mu) - p2 = ls_viterbi_naive(h, alleles, G, rho, mu) - cm1 = ls_viterbi_tree(h, alleles, ts, rho, mu, use_lib=True) - p3 = cm1.traceback() - cm2 = ls_viterbi_tree(h, alleles, ts, rho, mu, use_lib=False) - p4 = cm1.traceback() - self.assertCompressedMatricesEqual(cm1, cm2) - # print() - # m1 = H[p1, np.arange(m)] - # m2 = H[p2, np.arange(m)] - # m3 = H[p3, np.arange(m)] - # count = 
np.unique(p1).shape[0] - # print() - # print("\tp1 = ", p1) - # print("\tp2 = ", p2) - # print("\tp3 = ", p3) - # print("\tm1 = ", m1) - # print("\tm2 = ", m2) - # print("\t h = ", h) - proba1 = ls_path_log_probability(h, p1, alleles, G, rho, mu) - proba2 = ls_path_log_probability(h, p2, alleles, G, rho, mu) - proba3 = ls_path_log_probability(h, p3, alleles, G, rho, mu) - proba4 = ls_path_log_probability(h, p4, alleles, G, rho, mu) - # print("\t P = ", proba1, proba2) - self.assertAlmostEqual(proba1, proba2, places=6) - self.assertAlmostEqual(proba1, proba3, places=6) - self.assertAlmostEqual(proba1, proba4, places=6) - - -class TestMissingHaplotypes(LiStephensBase): - def verify(self, ts, alleles=tskit.ALLELES_01): - G = ts.genotype_matrix(alleles=alleles) - H = G.T - - rho = np.zeros(ts.num_sites) + 0.1 - rho[0] = 0 - mu = np.zeros(ts.num_sites) + 0.001 - - # When everything is missing data we should have no recombinations. - h = H[0].copy() - h[:] = tskit.MISSING_DATA - path = ls_viterbi_vectorised(h, alleles, G, rho, mu) - assert np.all(path == 0) - cm = ls_viterbi_tree(h, alleles, ts, rho, mu, use_lib=True) - # For the tree base algorithm it's not simple which particular sample - # gets chosen. - path = cm.traceback() - assert len(set(path)) == 1 - - # TODO Not clear what else we can check about missing data. - - -class TestForwardMatrixScaling(ForwardAlgorithmBase, unittest.TestCase): - """ - Tests that we get the correct values from scaling version of the matrix - algorithm works correctly. 
- """ + ts_check_mirror = mirror_coordinates(ts_check) + r_flip = np.insert(np.flip(r)[:-1], 0, 0) + cm_mirror = ls_forward_tree( + np.flip(s[0, :]), ts_check_mirror, r_flip, np.flip(mu) + ) + ll_mirror_tree = np.sum(np.log10(cm_mirror.normalisation_factor)) + self.assertAllClose(ll_tree, ll_mirror_tree) + + # Ensure that the decoded matrices are the same + F_mirror_matrix, c, ll = ls.forwards( + np.flip(H, axis=0), + np.flip(s, axis=1), + r_flip, + mutation_rate=np.flip(mu), + scale_mutation_based_on_n_alleles=False, + ) - def verify(self, ts, alleles=tskit.ALLELES_01): - G = ts.genotype_matrix(alleles=alleles) - computed_log_proba = False - for h, rho, mu in self.example_parameters(ts, alleles): - F_unscaled = ls_forward_matrix_unscaled(h, alleles, G, rho, mu) - F, S = ls_forward_matrix(h, alleles, G, rho, mu) - column = np.atleast_2d(np.cumprod(S)).T - F_scaled = F * column - self.assertAllClose(F_scaled, F_unscaled) - log_proba1 = forward_matrix_log_proba(F, S) - psum = np.sum(F_unscaled[-1]) - # If the computed probability is close to zero, there's no point in - # computing. - if psum > 1e-20: - computed_log_proba = True - log_proba2 = np.log(psum) - self.assertAlmostEqual(log_proba1, log_proba2) - assert computed_log_proba - - -class TestForwardTree(ForwardAlgorithmBase): - """ - Tests that the tree algorithm computes the same forward matrix as the - simple method. 
- """ + self.assertAllClose(F_mirror_matrix, cm_mirror.decode()) + self.assertAllClose(ll, ll_tree) - def verify(self, ts, alleles=tskit.ALLELES_01): - G = ts.genotype_matrix(alleles=alleles) - for h, rho, mu in self.example_parameters(ts, alleles): - F, S = ls_forward_matrix(h, alleles, G, rho, mu) - cm1 = ls_forward_tree(h, alleles, ts, rho, mu, use_lib=True) - cm2 = ls_forward_tree(h, alleles, ts, rho, mu, use_lib=False) - self.assertCompressedMatricesEqual(cm1, cm2) - Ft = cm1.decode() - self.assertAllClose(S, cm1.normalisation_factor) - self.assertAllClose(F, Ft) +class TestForwardHapTree(FBAlgorithmBase): + """Tests that the tree algorithm computes the same forward matrix as the + simple method.""" -class TestAllPaths(unittest.TestCase): - """ - Tests that we compute the correct forward probablities if we sum over all - possible paths through the genotype matrix. - """ + def verify(self, ts): + for n, H, s, r, mu in self.example_parameters_haplotypes(ts): + for scale_mutation in [False, True]: + F, c, ll = ls.forwards( + H, + s, + r, + mutation_rate=mu, + scale_mutation_based_on_n_alleles=scale_mutation, + ) + # Note, need to remove the first sample from the ts, and ensure + # that invariant sites aren't removed. 
+ ts_check = ts.simplify(range(1, n + 1), filter_sites=False) + cm = ls_forward_tree( + s[0, :], + ts_check, + r, + mu, + scale_mutation_based_on_n_alleles=scale_mutation, + ) + self.assertAllClose(cm.decode(), F) + ll_tree = np.sum(np.log10(cm.normalisation_factor)) + self.assertAllClose(ll, ll_tree) - def verify(self, G, h): - m, n = G.shape - rho = np.zeros(m) + 0.1 - mu = np.zeros(m) + 0.01 - rho[0] = 0 - proba = 0 - for path in itertools.product(range(n), repeat=m): - proba += ls_path_probability(h, path, G, rho, mu) - - alleles = [["0", "1"] for _ in range(m)] - F = ls_forward_matrix_unscaled(h, alleles, G, rho, mu) - forward_proba = np.sum(F[-1]) - self.assertAlmostEqual(proba, forward_proba) - - def test_n3_m4(self): - G = np.array( - [ - # fmt: off - [1, 0, 0], - [0, 0, 1], - [1, 0, 1], - [0, 1, 1], - # fmt: on - ] - ) - self.verify(G, [0, 0, 0, 0]) - self.verify(G, [1, 1, 1, 1]) - self.verify(G, [1, 1, 0, 0]) - def test_n4_m5(self): - G = np.array( - [ - # fmt: off - [1, 0, 0, 0], - [0, 0, 1, 1], - [1, 0, 1, 1], - [0, 1, 1, 0], - # fmt: on - ] - ) - self.verify(G, [0, 0, 0, 0, 0]) - self.verify(G, [1, 1, 1, 1, 1]) - self.verify(G, [1, 1, 0, 0, 0]) +class TestForwardBackwardTree(FBAlgorithmBase): + """Tests that the tree algorithm computes the same forward matrix as the + simple method.""" - def test_n5_m5(self): - G = np.zeros((5, 5), dtype=int) - np.fill_diagonal(G, 1) - self.verify(G, [0, 0, 0, 0, 0]) - self.verify(G, [1, 1, 1, 1, 1]) - self.verify(G, [1, 1, 0, 0, 0]) + def verify(self, ts): + for n, H, s, r, mu in self.example_parameters_haplotypes(ts): + F, c, ll = ls.forwards( + H, s, r, mutation_rate=mu, scale_mutation_based_on_n_alleles=False + ) + B = ls.backwards( + H, + s, + c, + r, + mutation_rate=mu, + scale_mutation_based_on_n_alleles=False, + ) + # Note, need to remove the first sample from the ts, and ensure that + # invariant sites aren't removed. 
+ ts_check = ts.simplify(range(1, n + 1), filter_sites=False) + c_f = ls_forward_tree(s[0, :], ts_check, r, mu) + ll_tree = np.sum(np.log10(c_f.normalisation_factor)) + + ts_check_mirror = mirror_coordinates(ts_check) + r_flip = np.flip(r) + c_b = ls_backward_tree( + np.flip(s[0, :]), + ts_check_mirror, + r_flip, + np.flip(mu), + np.flip(c_f.normalisation_factor), + ) + B_tree = np.flip(c_b.decode(), axis=0) + F_tree = c_f.decode() -class TestBasicViterbi: - """ - Very simple tests of the Viterbi algorithm. - """ + self.assertAllClose(B, B_tree) + self.assertAllClose(F, F_tree) + self.assertAllClose(ll, ll_tree) - def verify_exact_match(self, G, h, path): - m, n = G.shape - rho = np.zeros(m) + 1e-9 - mu = np.zeros(m) # Set mu to zero exact match - rho[0] = 0 - alleles = [["0", "1"] for _ in range(m)] - path1 = ls_viterbi_naive(h, alleles, G, rho, mu) - path2 = ls_viterbi_vectorised(h, alleles, G, rho, mu) - assert list(path1) == path - assert list(path2) == path - - def test_n2_m6_exact(self): - G = np.array( - [ - # fmt: off - [1, 0], - [1, 0], - [1, 0], - [0, 1], - [0, 1], - [0, 1], - # fmt: on - ] - ) - self.verify_exact_match(G, [1, 1, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1]) - self.verify_exact_match(G, [0, 0, 0, 0, 0, 0], [1, 1, 1, 0, 0, 0]) - self.verify_exact_match(G, [0, 0, 0, 1, 1, 1], [1, 1, 1, 1, 1, 1]) - self.verify_exact_match(G, [0, 0, 0, 1, 1, 0], [1, 1, 1, 1, 1, 0]) - self.verify_exact_match(G, [0, 0, 0, 0, 1, 0], [1, 1, 1, 0, 1, 0]) - - def test_n3_m6_exact(self): - G = np.array( - [ - # fmt: off - [1, 0, 1], - [1, 0, 0], - [1, 0, 1], - [0, 1, 0], - [0, 1, 1], - [0, 1, 0], - # fmt: on - ] - ) - self.verify_exact_match(G, [1, 1, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1]) - self.verify_exact_match(G, [0, 0, 0, 0, 0, 0], [1, 1, 1, 0, 0, 0]) - self.verify_exact_match(G, [0, 0, 0, 1, 1, 1], [1, 1, 1, 1, 1, 1]) - self.verify_exact_match(G, [1, 0, 1, 0, 1, 0], [2, 2, 2, 2, 2, 2]) - def test_n3_m6(self): - G = np.array( - [ - # fmt: off - [1, 0, 1], - [1, 0, 0], - [1, 0, 
1], - [0, 1, 0], - [0, 1, 1], - [0, 1, 0], - # fmt: on - ] - ) +class TestTreeViterbiHap(VitAlgorithmBase): + """Test that we have the same log-likelihood between tree and matrix + implementations""" - m, n = G.shape - rho = np.zeros(m) + 1e-2 - mu = np.zeros(m) - rho[0] = 0 - alleles = [["0", "1"] for _ in range(m)] - h = np.ones(m, dtype=int) - path1 = ls_viterbi_naive(h, alleles, G, rho, mu) - - # Add in mutation at a very low rate. - mu[:] = 1e-8 - path2 = ls_viterbi_naive(h, alleles, G, rho, mu) - path3 = ls_viterbi_vectorised(h, alleles, G, rho, mu) - assert np.array_equal(path1, path2) - assert np.array_equal(path2, path3) + def verify(self, ts): + for n, H, s, r, mu in self.example_parameters_haplotypes(ts): + path, ll = ls.viterbi( + H, s, r, mutation_rate=mu, scale_mutation_based_on_n_alleles=False + ) + ts_check = ts.simplify(range(1, n + 1), filter_sites=False) + cm = ls_viterbi_tree(s[0, :], ts_check, r, mu) + ll_tree = np.sum(np.log10(cm.normalisation_factor)) + self.assertAllClose(ll, ll_tree) + + # Now, need to ensure that the likelihood of the preferred path is + # the same as ll_tree (and ll). + path_tree = cm.traceback() + ll_check = ls.path_ll( + H, + s, + path_tree, + r, + mutation_rate=mu, + scale_mutation_based_on_n_alleles=False, + ) + self.assertAllClose(ll, ll_check) From 67f2ef35f4a0acfc1fbe3bb1f0eeb94bdcf0bc4c Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Sat, 15 Jul 2023 19:37:50 +0100 Subject: [PATCH 64/84] Clarify tree.mrca is MRCA does not exist --- python/tskit/trees.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 10b3bd9782..6c2fbb05a5 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -997,7 +997,9 @@ def mrca(self, *args): Returns the most recent common ancestor of the specified nodes. :param int `*args`: input node IDs, must be at least 2. - :return: The most recent common ancestor of input nodes. 
+ :return: The node ID of the most recent common ancestor of the + input nodes, or :data:`tskit.NULL` if the nodes do not share + a common ancestor in the tree. :rtype: int """ if len(args) < 2: From f93e79254b84e095ae13b80b6a47a27dcc907cb3 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 7 Jul 2023 16:40:44 +0100 Subject: [PATCH 65/84] Initial version of tsk_tree_position_t and tests --- c/tests/test_trees.c | 153 +++++++++ c/tskit/trees.c | 158 +++++++++ c/tskit/trees.h | 34 ++ python/tests/test_tree_positioning.py | 470 ++++++++++++++++++++++++++ python/tests/tsutil.py | 192 +++++++++++ 5 files changed, 1007 insertions(+) create mode 100644 python/tests/test_tree_positioning.py diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c index cceb11d6fd..63b7292322 100644 --- a/c/tests/test_trees.c +++ b/c/tests/test_trees.c @@ -175,6 +175,97 @@ verify_individual_nodes(tsk_treeseq_t *ts) } } +static void +verify_tree_pos(const tsk_treeseq_t *ts, tsk_size_t num_trees, tsk_id_t *tree_parents) +{ + int ret; + const tsk_size_t N = tsk_treeseq_get_num_nodes(ts); + const tsk_id_t *edges_parent = ts->tables->edges.parent; + const tsk_id_t *edges_child = ts->tables->edges.child; + tsk_tree_position_t tree_pos; + tsk_id_t *known_parent; + tsk_id_t *parent = tsk_malloc(N * sizeof(*parent)); + tsk_id_t u, index, j, e; + bool valid; + + CU_ASSERT_FATAL(parent != NULL); + + ret = tsk_tree_position_init(&tree_pos, ts, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + for (u = 0; u < (tsk_id_t) N; u++) { + parent[u] = TSK_NULL; + } + + for (index = 0; index < (tsk_id_t) num_trees; index++) { + known_parent = tree_parents + N * (tsk_size_t) index; + + valid = tsk_tree_position_next(&tree_pos); + CU_ASSERT_TRUE(valid); + CU_ASSERT_EQUAL(index, tree_pos.index); + + for (j = tree_pos.out.start; j < tree_pos.out.stop; j++) { + e = tree_pos.out.order[j]; + parent[edges_child[e]] = TSK_NULL; + } + + for (j = tree_pos.in.start; j < tree_pos.in.stop; j++) { + e = tree_pos.in.order[j]; + 
parent[edges_child[e]] = edges_parent[e]; + } + + for (u = 0; u < (tsk_id_t) N; u++) { + CU_ASSERT_EQUAL(parent[u], known_parent[u]); + } + } + + valid = tsk_tree_position_next(&tree_pos); + CU_ASSERT_FALSE(valid); + for (j = tree_pos.out.start; j < tree_pos.out.stop; j++) { + e = tree_pos.out.order[j]; + parent[edges_child[e]] = TSK_NULL; + } + for (u = 0; u < (tsk_id_t) N; u++) { + CU_ASSERT_EQUAL(parent[u], TSK_NULL); + } + + for (index = (tsk_id_t) num_trees - 1; index >= 0; index--) { + known_parent = tree_parents + N * (tsk_size_t) index; + + valid = tsk_tree_position_prev(&tree_pos); + CU_ASSERT_TRUE(valid); + CU_ASSERT_EQUAL(index, tree_pos.index); + + for (j = tree_pos.out.start; j > tree_pos.out.stop; j--) { + e = tree_pos.out.order[j]; + parent[edges_child[e]] = TSK_NULL; + } + + for (j = tree_pos.in.start; j > tree_pos.in.stop; j--) { + CU_ASSERT_FATAL(j >= 0); + e = tree_pos.in.order[j]; + parent[edges_child[e]] = edges_parent[e]; + } + + for (u = 0; u < (tsk_id_t) N; u++) { + CU_ASSERT_EQUAL(parent[u], known_parent[u]); + } + } + + valid = tsk_tree_position_prev(&tree_pos); + CU_ASSERT_FALSE(valid); + for (j = tree_pos.out.start; j > tree_pos.out.stop; j--) { + e = tree_pos.out.order[j]; + parent[edges_child[e]] = TSK_NULL; + } + for (u = 0; u < (tsk_id_t) N; u++) { + CU_ASSERT_EQUAL(parent[u], TSK_NULL); + } + + tsk_tree_position_free(&tree_pos); + tsk_safe_free(parent); +} + static void verify_trees(tsk_treeseq_t *ts, tsk_size_t num_trees, tsk_id_t *parents) { @@ -233,6 +324,8 @@ verify_trees(tsk_treeseq_t *ts, tsk_size_t num_trees, tsk_id_t *parents) CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(ts), breakpoints[j]); tsk_tree_free(&tree); + + verify_tree_pos(ts, num_trees, parents); } static tsk_tree_t * @@ -5233,6 +5326,65 @@ test_single_tree_tracked_samples(void) tsk_tree_free(&tree); } +static void +test_single_tree_tree_pos(void) +{ + tsk_treeseq_t ts; + tsk_tree_position_t tree_pos; + bool valid; + int ret; + + tsk_treeseq_from_text(&ts, 1, 
single_tree_ex_nodes, single_tree_ex_edges, NULL, NULL, + NULL, NULL, NULL, 0); + + ret = tsk_tree_position_init(&tree_pos, &ts, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + valid = tsk_tree_position_next(&tree_pos); + CU_ASSERT_FATAL(valid); + + CU_ASSERT_EQUAL_FATAL(tree_pos.interval.left, 0); + CU_ASSERT_EQUAL_FATAL(tree_pos.interval.right, 1); + CU_ASSERT_EQUAL_FATAL(tree_pos.in.start, 0); + CU_ASSERT_EQUAL_FATAL(tree_pos.in.stop, 6); + CU_ASSERT_EQUAL_FATAL(tree_pos.in.order, ts.tables->indexes.edge_insertion_order); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.start, 0); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.stop, 0); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.order, ts.tables->indexes.edge_removal_order); + + valid = tsk_tree_position_next(&tree_pos); + CU_ASSERT_FATAL(!valid); + + tsk_tree_position_print_state(&tree_pos, _devnull); + + CU_ASSERT_EQUAL_FATAL(tree_pos.index, -1); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.start, 0); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.stop, 6); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.order, ts.tables->indexes.edge_removal_order); + + valid = tsk_tree_position_prev(&tree_pos); + CU_ASSERT_FATAL(valid); + + CU_ASSERT_EQUAL_FATAL(tree_pos.interval.left, 0); + CU_ASSERT_EQUAL_FATAL(tree_pos.interval.right, 1); + CU_ASSERT_EQUAL_FATAL(tree_pos.in.start, 5); + CU_ASSERT_EQUAL_FATAL(tree_pos.in.stop, -1); + CU_ASSERT_EQUAL_FATAL(tree_pos.in.order, ts.tables->indexes.edge_removal_order); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.start, 5); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.stop, 5); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.order, ts.tables->indexes.edge_insertion_order); + + valid = tsk_tree_position_prev(&tree_pos); + CU_ASSERT_FATAL(!valid); + + CU_ASSERT_EQUAL_FATAL(tree_pos.index, -1); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.start, 5); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.stop, -1); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.order, ts.tables->indexes.edge_insertion_order); + + tsk_tree_position_free(&tree_pos); + tsk_treeseq_free(&ts); +} + 
/*======================================================= * Multi tree tests. *======================================================*/ @@ -8185,6 +8337,7 @@ main(int argc, char **argv) { "test_single_tree_map_mutations_internal_samples", test_single_tree_map_mutations_internal_samples }, { "test_single_tree_tracked_samples", test_single_tree_tracked_samples }, + { "test_single_tree_tree_pos", test_single_tree_tree_pos }, /* Multi tree tests */ { "test_simple_multi_tree", test_simple_multi_tree }, diff --git a/c/tskit/trees.c b/c/tskit/trees.c index dac3ac154b..bce3b16b9d 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -3553,6 +3553,164 @@ tsk_treeseq_split_edges(const tsk_treeseq_t *self, double time, tsk_flags_t flag return ret; } +/* ======================================================== * + * tree_position + * ======================================================== */ + +static void +tsk_tree_position_set_null(tsk_tree_position_t *self) +{ + self->index = -1; + self->interval.left = 0; + self->interval.right = 0; +} + +int +tsk_tree_position_init(tsk_tree_position_t *self, const tsk_treeseq_t *tree_sequence, + tsk_flags_t TSK_UNUSED(options)) +{ + memset(self, 0, sizeof(*self)); + self->tree_sequence = tree_sequence; + tsk_tree_position_set_null(self); + return 0; +} + +int +tsk_tree_position_free(tsk_tree_position_t *TSK_UNUSED(self)) +{ + return 0; +} + +int +tsk_tree_position_print_state(const tsk_tree_position_t *self, FILE *out) +{ + fprintf(out, "Tree position state\n"); + fprintf(out, "index = %d\n", (int) self->index); + fprintf( + out, "out = start=%d\tstop=%d\n", (int) self->out.start, (int) self->out.stop); + fprintf( + out, "in = start=%d\tstop=%d\n", (int) self->in.start, (int) self->in.stop); + return 0; +} + +bool +tsk_tree_position_next(tsk_tree_position_t *self) +{ + const tsk_table_collection_t *tables = self->tree_sequence->tables; + const tsk_id_t M = (tsk_id_t) tables->edges.num_rows; + const tsk_id_t num_trees = (tsk_id_t) 
self->tree_sequence->num_trees; + const double *restrict left_coords = tables->edges.left; + const tsk_id_t *restrict left_order = tables->indexes.edge_insertion_order; + const double *restrict right_coords = tables->edges.right; + const tsk_id_t *restrict right_order = tables->indexes.edge_removal_order; + const double *restrict breakpoints = self->tree_sequence->breakpoints; + tsk_id_t j, left_current_index, right_current_index; + double left; + + if (self->index == -1) { + self->interval.right = 0; + self->in.stop = 0; + self->out.stop = 0; + self->direction = TSK_DIR_FORWARD; + } + + if (self->direction == TSK_DIR_FORWARD) { + left_current_index = self->in.stop; + right_current_index = self->out.stop; + } else { + left_current_index = self->out.stop + 1; + right_current_index = self->in.stop + 1; + } + + left = self->interval.right; + + j = right_current_index; + self->out.start = j; + while (j < M && right_coords[right_order[j]] == left) { + j++; + } + self->out.stop = j; + self->out.order = right_order; + + j = left_current_index; + self->in.start = j; + while (j < M && left_coords[left_order[j]] == left) { + j++; + } + self->in.stop = j; + self->in.order = left_order; + + self->direction = TSK_DIR_FORWARD; + self->index++; + if (self->index == num_trees) { + tsk_tree_position_set_null(self); + } else { + self->interval.left = left; + self->interval.right = breakpoints[self->index + 1]; + } + return self->index != -1; +} + +bool +tsk_tree_position_prev(tsk_tree_position_t *self) +{ + const tsk_table_collection_t *tables = self->tree_sequence->tables; + const tsk_id_t M = (tsk_id_t) tables->edges.num_rows; + const double sequence_length = tables->sequence_length; + const tsk_id_t num_trees = (tsk_id_t) self->tree_sequence->num_trees; + const double *restrict left_coords = tables->edges.left; + const tsk_id_t *restrict left_order = tables->indexes.edge_insertion_order; + const double *restrict right_coords = tables->edges.right; + const tsk_id_t *restrict 
right_order = tables->indexes.edge_removal_order; + const double *restrict breakpoints = self->tree_sequence->breakpoints; + tsk_id_t j, left_current_index, right_current_index; + double right; + + if (self->index == -1) { + self->index = num_trees; + self->interval.left = sequence_length; + self->in.stop = M - 1; + self->out.stop = M - 1; + self->direction = TSK_DIR_REVERSE; + } + + if (self->direction == TSK_DIR_REVERSE) { + left_current_index = self->out.stop; + right_current_index = self->in.stop; + } else { + left_current_index = self->in.stop - 1; + right_current_index = self->out.stop - 1; + } + + right = self->interval.left; + + j = left_current_index; + self->out.start = j; + while (j >= 0 && left_coords[left_order[j]] == right) { + j--; + } + self->out.stop = j; + self->out.order = left_order; + + j = right_current_index; + self->in.start = j; + while (j >= 0 && right_coords[right_order[j]] == right) { + j--; + } + self->in.stop = j; + self->in.order = right_order; + + self->index--; + self->direction = TSK_DIR_REVERSE; + if (self->index == -1) { + tsk_tree_position_set_null(self); + } else { + self->interval.left = breakpoints[self->index]; + self->interval.right = right; + } + return self->index != -1; +} + /* ======================================================== * * Tree * ======================================================== */ diff --git a/c/tskit/trees.h b/c/tskit/trees.h index b36a38c31f..95c66a6ac7 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1739,6 +1739,40 @@ bool tsk_tree_equals(const tsk_tree_t *self, const tsk_tree_t *other); int tsk_diff_iter_init_from_ts( tsk_diff_iter_t *self, const tsk_treeseq_t *tree_sequence, tsk_flags_t options); +/* Temporarily putting this here to avoid problems with doxygen. Will need to + * move up the file later when it gets incorporated into the tsk_tree_t object. 
+ */ +typedef struct { + tsk_id_t index; + struct { + double left; + double right; + } interval; + struct { + tsk_id_t start; + tsk_id_t stop; + const tsk_id_t *order; + } in; + struct { + tsk_id_t start; + tsk_id_t stop; + const tsk_id_t *order; + } out; + tsk_id_t left_current_index; + tsk_id_t right_current_index; + int direction; + const tsk_treeseq_t *tree_sequence; +} tsk_tree_position_t; + +int tsk_tree_position_init( + tsk_tree_position_t *self, const tsk_treeseq_t *tree_sequence, tsk_flags_t options); +int tsk_tree_position_free(tsk_tree_position_t *self); +int tsk_tree_position_print_state(const tsk_tree_position_t *self, FILE *out); +bool tsk_tree_position_next(tsk_tree_position_t *self); +bool tsk_tree_position_prev(tsk_tree_position_t *self); +int tsk_tree_position_seek_forward(tsk_tree_position_t *self, tsk_id_t index); +int tsk_tree_position_seek_backward(tsk_tree_position_t *self, tsk_id_t index); + #ifdef __cplusplus } #endif diff --git a/python/tests/test_tree_positioning.py b/python/tests/test_tree_positioning.py new file mode 100644 index 0000000000..961f0810f7 --- /dev/null +++ b/python/tests/test_tree_positioning.py @@ -0,0 +1,470 @@ +# MIT License +# +# Copyright (c) 2023 Tskit Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Tests for tree iterator schemes. Mostly used to develop the incremental +iterator infrastructure. +""" +import msprime +import numpy as np +import pytest + +import tests +import tskit +from tests import tsutil +from tests.test_highlevel import get_example_tree_sequences + +# ↑ See https://github.com/tskit-dev/tskit/issues/1804 for when +# we can remove this. + + +class StatefulTree: + """ + Just enough functionality to mimic the low-level tree implementation + for testing of forward/backward moving. + """ + + def __init__(self, ts): + self.ts = ts + self.tree_pos = tsutil.TreePosition(ts) + self.parent = [-1 for _ in range(ts.num_nodes)] + + def __str__(self): + s = f"parent: {self.parent}\nposition:\n" + for line in str(self.tree_pos).splitlines(): + s += f"\t{line}\n" + return s + + def assert_equal(self, other): + assert self.parent == other.parent + assert self.tree_pos.index == other.tree_pos.index + assert self.tree_pos.interval == other.tree_pos.interval + + def next(self): # NOQA: A003 + valid = self.tree_pos.next() + if valid: + for j in range(self.tree_pos.out_range.start, self.tree_pos.out_range.stop): + e = self.tree_pos.out_range.order[j] + c = self.ts.edges_child[e] + self.parent[c] = -1 + for j in range(self.tree_pos.in_range.start, self.tree_pos.in_range.stop): + e = self.tree_pos.in_range.order[j] + c = self.ts.edges_child[e] + p = self.ts.edges_parent[e] + self.parent[c] = p + return valid + + def prev(self): + valid = self.tree_pos.prev() + if valid: + for j in 
range( + self.tree_pos.out_range.start, self.tree_pos.out_range.stop, -1 + ): + e = self.tree_pos.out_range.order[j] + c = self.ts.edges_child[e] + self.parent[c] = -1 + for j in range( + self.tree_pos.in_range.start, self.tree_pos.in_range.stop, -1 + ): + e = self.tree_pos.in_range.order[j] + c = self.ts.edges_child[e] + p = self.ts.edges_parent[e] + self.parent[c] = p + return valid + + def iter_forward(self, index): + while self.tree_pos.index != index: + self.next() + + def seek_forward(self, index): + old_left, old_right = self.tree_pos.interval + self.tree_pos.seek_forward(index) + left, right = self.tree_pos.interval + # print() + # print("Current interval:", old_left, old_right) + # print("New interval:", left, right) + # print("index:", index, "out_range:", self.tree_pos.out_range) + for j in range(self.tree_pos.out_range.start, self.tree_pos.out_range.stop): + e = self.tree_pos.out_range.order[j] + e_left = self.ts.edges_left[e] + # We only need to remove an edge if it's in the current tree, which + # can only happen if the edge's left coord is < the current tree's + # right coordinate. + if e_left < old_right: + c = self.ts.edges_child[e] + assert self.parent[c] != -1 + self.parent[c] = -1 + assert e_left < left + # print("index:", index, "in_range:", self.tree_pos.in_range) + for j in range(self.tree_pos.in_range.start, self.tree_pos.in_range.stop): + e = self.tree_pos.in_range.order[j] + if self.ts.edges_left[e] <= left < self.ts.edges_right[e]: + # print("keep", j, e, self.ts.edges_left[e], self.ts.edges_right[e]) + # print( + # "INSERT:", + # self.ts.edge(e), + # self.ts.nodes_time[self.ts.edges_parent[e]], + # ) + c = self.ts.edges_child[e] + p = self.ts.edges_parent[e] + self.parent[c] = p + else: + a = self.tree_pos.in_range.start + b = self.tree_pos.in_range.stop + # The first and last indexes in the range should always be valid + # for the tree. 
+ assert a < j < b - 1 + # print("skip", j, e, self.ts.edges_left[e], self.ts.edges_right[e]) + + def seek_backward(self, index): + # TODO + while self.tree_pos.index != index: + self.prev() + + def iter_backward(self, index): + while self.tree_pos.index != index: + self.prev() + + +def check_iters_forward(ts): + alg_t_output = tsutil.algorithm_T(ts) + lib_tree = tskit.Tree(ts) + tree_pos = tsutil.TreePosition(ts) + sample_count = np.zeros(ts.num_nodes, dtype=int) + sample_count[ts.samples()] = 1 + parent1 = [-1 for _ in range(ts.num_nodes)] + i = 0 + lib_tree.next() + while tree_pos.next(): + out_times = [] + for j in range(tree_pos.out_range.start, tree_pos.out_range.stop): + e = tree_pos.out_range.order[j] + c = ts.edges_child[e] + p = ts.edges_parent[e] + out_times.append(ts.nodes_time[p]) + parent1[c] = -1 + in_times = [] + for j in range(tree_pos.in_range.start, tree_pos.in_range.stop): + e = tree_pos.in_range.order[j] + c = ts.edges_child[e] + p = ts.edges_parent[e] + in_times.append(ts.nodes_time[p]) + parent1[c] = p + # We must visit the edges in *increasing* time order on the way in, + # and *decreasing* order on the way out. Otherwise we get quadratic + # behaviour for algorithms that need to propagate changes up to the + # root. 
+ assert out_times == sorted(out_times, reverse=True) + assert in_times == sorted(in_times) + + interval, parent2 = next(alg_t_output) + assert list(interval) == list(tree_pos.interval) + assert parent1 == parent2 + + assert lib_tree.index == i + assert list(lib_tree.interval) == list(interval) + assert list(lib_tree.parent_array[:-1]) == parent1 + + lib_tree.next() + i += 1 + assert i == ts.num_trees + assert lib_tree.index == -1 + assert next(alg_t_output, None) is None + + +def check_iters_back(ts): + alg_t_output = [ + (list(interval), list(parent)) for interval, parent in tsutil.algorithm_T(ts) + ] + i = len(alg_t_output) - 1 + + lib_tree = tskit.Tree(ts) + tree_pos = tsutil.TreePosition(ts) + parent1 = [-1 for _ in range(ts.num_nodes)] + + lib_tree.last() + + while tree_pos.prev(): + # print(tree_pos.out_range) + out_times = [] + for j in range(tree_pos.out_range.start, tree_pos.out_range.stop, -1): + e = tree_pos.out_range.order[j] + c = ts.edges_child[e] + p = ts.edges_parent[e] + out_times.append(ts.nodes_time[p]) + parent1[c] = -1 + in_times = [] + for j in range(tree_pos.in_range.start, tree_pos.in_range.stop, -1): + e = tree_pos.in_range.order[j] + c = ts.edges_child[e] + p = ts.edges_parent[e] + in_times.append(ts.nodes_time[p]) + parent1[c] = p + + # We must visit the edges in *increasing* time order on the way in, + # and *decreasing* order on the way out. Otherwise we get quadratic + # behaviour for algorithms that need to propagate changes up to the + # root. 
+ assert out_times == sorted(out_times, reverse=True) + assert in_times == sorted(in_times) + + interval, parent2 = alg_t_output[i] + assert list(interval) == list(tree_pos.interval) + assert parent1 == parent2 + + assert lib_tree.index == i + assert list(lib_tree.interval) == list(interval) + assert list(lib_tree.parent_array[:-1]) == parent1 + + lib_tree.prev() + i -= 1 + + assert lib_tree.index == -1 + assert i == -1 + + +def check_forward_back_sweep(ts): + alg_t_output = [ + (list(interval), list(parent)) for interval, parent in tsutil.algorithm_T(ts) + ] + for j in range(ts.num_trees - 1): + tree = StatefulTree(ts) + # Seek forward to j + k = 0 + while k <= j: + tree.next() + interval, parent = alg_t_output[k] + assert tree.tree_pos.index == k + assert list(tree.tree_pos.interval) == interval + assert parent == tree.parent + k += 1 + k = j + # And back to zero + while k >= 0: + interval, parent = alg_t_output[k] + assert tree.tree_pos.index == k + assert list(tree.tree_pos.interval) == interval + assert parent == tree.parent + tree.prev() + k -= 1 + + +class TestDirectionSwitching: + # 2.00┊ ┊ 4 ┊ 4 ┊ 4 ┊ + # ┊ ┊ ┏━┻┓ ┊ ┏┻━┓ ┊ ┏┻━┓ ┊ + # 1.00┊ 3 ┊ ┃ 3 ┊ 3 ┃ ┊ 3 ┃ ┊ + # ┊ ┏━╋━┓ ┊ ┃ ┏┻┓ ┊ ┏┻┓ ┃ ┊ ┏┻┓ ┃ ┊ + # 0.00┊ 0 1 2 ┊ 0 1 2 ┊ 0 2 1 ┊ 0 1 2 ┊ + # 0 1 2 3 4 + # index 0 1 2 3 + def ts(self): + return tsutil.all_trees_ts(3) + + @pytest.mark.parametrize("index", [1, 2, 3]) + def test_forward_to_prev(self, index): + tree1 = StatefulTree(self.ts()) + tree1.iter_forward(index) + tree1.prev() + tree2 = StatefulTree(self.ts()) + tree2.iter_forward(index - 1) + tree1.assert_equal(tree2) + tree2 = StatefulTree(self.ts()) + tree2.iter_backward(index - 1) + tree1.assert_equal(tree2) + + @pytest.mark.parametrize("index", [1, 2, 3]) + def test_seek_forward_from_prev(self, index): + tree1 = StatefulTree(self.ts()) + tree1.iter_forward(index) + tree1.prev() + tree1.seek_forward(index) + tree2 = StatefulTree(self.ts()) + tree2.iter_forward(index) + tree1.assert_equal(tree2) + 
+ @pytest.mark.parametrize("index", [0, 1, 2]) + def test_backward_to_next(self, index): + tree1 = StatefulTree(self.ts()) + tree1.iter_backward(index) + tree1.next() + tree2 = StatefulTree(self.ts()) + tree2.iter_backward(index + 1) + tree1.assert_equal(tree2) + tree2 = StatefulTree(self.ts()) + tree2.iter_forward(index + 1) + tree1.assert_equal(tree2) + + @pytest.mark.parametrize("index", [1, 2, 3]) + def test_forward_next_prev(self, index): + tree1 = StatefulTree(self.ts()) + tree1.iter_forward(index) + tree1.prev() + tree2 = StatefulTree(self.ts()) + tree2.iter_forward(index - 1) + tree1.assert_equal(tree2) + tree2 = StatefulTree(self.ts()) + tree2.iter_backward(index - 1) + tree1.assert_equal(tree2) + + @pytest.mark.parametrize("index", [1, 2, 3]) + def test_seek_forward_next_prev(self, index): + tree1 = StatefulTree(self.ts()) + tree1.iter_forward(index) + tree1.prev() + tree2 = StatefulTree(self.ts()) + tree2.seek_forward(index - 1) + tree1.assert_equal(tree2) + tree2 = StatefulTree(self.ts()) + tree2.iter_backward(index - 1) + tree1.assert_equal(tree2) + + @pytest.mark.parametrize("index", [1, 2, 3]) + def test_seek_forward_from_null(self, index): + tree1 = StatefulTree(self.ts()) + tree1.seek_forward(index) + tree2 = StatefulTree(self.ts()) + tree2.iter_forward(index) + tree1.assert_equal(tree2) + + def test_seek_forward_next_null(self): + tree1 = StatefulTree(self.ts()) + tree1.seek_forward(3) + tree1.next() + assert tree1.tree_pos.index == -1 + assert list(tree1.tree_pos.interval) == [0, 0] + + +class TestSeeking: + @tests.cached_example + def ts(self): + ts = tsutil.all_trees_ts(4) + assert ts.num_trees == 26 + return ts + + @pytest.mark.parametrize("index", range(26)) + def test_seek_forward_from_null(self, index): + tree1 = StatefulTree(self.ts()) + tree1.seek_forward(index) + tree2 = StatefulTree(self.ts()) + tree2.iter_forward(index) + tree1.assert_equal(tree2) + + @pytest.mark.parametrize("index", range(1, 26)) + def 
test_seek_forward_from_first(self, index): + tree1 = StatefulTree(self.ts()) + tree1.next() + tree1.seek_forward(index) + tree2 = StatefulTree(self.ts()) + tree2.iter_forward(index) + tree1.assert_equal(tree2) + + @pytest.mark.parametrize("index", range(1, 26)) + def test_seek_last_from_index(self, index): + ts = self.ts() + tree1 = StatefulTree(ts) + tree1.iter_forward(index) + tree1.seek_forward(ts.num_trees - 1) + tree2 = StatefulTree(ts) + tree2.prev() + tree1.assert_equal(tree2) + + +class TestAllTreesTs: + @pytest.mark.parametrize("n", [2, 3, 4]) + def test_forward_full(self, n): + ts = tsutil.all_trees_ts(n) + check_iters_forward(ts) + + @pytest.mark.parametrize("n", [2, 3, 4]) + def test_back_full(self, n): + ts = tsutil.all_trees_ts(n) + check_iters_back(ts) + + @pytest.mark.parametrize("n", [2, 3, 4]) + def test_forward_back(self, n): + ts = tsutil.all_trees_ts(n) + check_forward_back_sweep(ts) + + +class TestManyTreesSimulationExample: + @tests.cached_example + def ts(self): + ts = msprime.sim_ancestry( + 10, sequence_length=1000, recombination_rate=0.1, random_seed=1234 + ) + assert ts.num_trees > 250 + return ts + + @pytest.mark.parametrize("index", [1, 5, 10, 50, 100]) + def test_seek_forward_from_null(self, index): + ts = self.ts() + tree1 = StatefulTree(ts) + tree1.seek_forward(index) + tree2 = StatefulTree(ts) + tree2.iter_forward(index) + tree1.assert_equal(tree2) + + @pytest.mark.parametrize("num_trees", [1, 5, 10, 50, 100]) + def test_seek_forward_from_mid(self, num_trees): + ts = self.ts() + start_index = ts.num_trees // 2 + dest_index = min(start_index + num_trees, ts.num_trees - 1) + tree1 = StatefulTree(ts) + tree1.iter_forward(start_index) + tree1.seek_forward(dest_index) + tree2 = StatefulTree(ts) + tree2.iter_forward(dest_index) + tree1.assert_equal(tree2) + + def test_forward_full(self): + check_iters_forward(self.ts()) + + def test_back_full(self): + check_iters_back(self.ts()) + + +class TestSuiteExamples: + 
@pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_forward_full(self, ts): + check_iters_forward(ts) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_back_full(self, ts): + check_iters_back(ts) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_seek_forward_from_null(self, ts): + index = ts.num_trees // 2 + tree1 = StatefulTree(ts) + tree1.seek_forward(index) + tree2 = StatefulTree(ts) + tree2.iter_forward(index) + tree1.assert_equal(tree2) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_seek_forward_from_first(self, ts): + index = ts.num_trees - 1 + tree1 = StatefulTree(ts) + tree1.next() + tree1.seek_forward(index) + tree2 = StatefulTree(ts) + tree2.iter_forward(index) + tree1.assert_equal(tree2) diff --git a/python/tests/tsutil.py b/python/tests/tsutil.py index 34334e9be0..b86a159274 100644 --- a/python/tests/tsutil.py +++ b/python/tests/tsutil.py @@ -24,11 +24,13 @@ A collection of utilities to edit and construct tree sequences. 
""" import collections +import dataclasses import functools import json import random import string import struct +import typing import msprime import numpy as np @@ -1713,6 +1715,196 @@ def iterate(self): left = right +FORWARD = 1 +REVERSE = -1 + + +@dataclasses.dataclass +class Interval: + left: float + right: float + + def __iter__(self): + yield self.left + yield self.right + + +@dataclasses.dataclass +class EdgeRange: + start: int + stop: int + order: typing.List + + +class TreePosition: + def __init__(self, ts): + self.ts = ts + self.index = -1 + self.direction = 0 + self.interval = Interval(0, 0) + self.in_range = EdgeRange(0, 0, None) + self.out_range = EdgeRange(0, 0, None) + + def __str__(self): + s = f"index: {self.index}\ninterval: {self.interval}\n" + s += f"direction: {self.direction}\n" + s += f"in_range: {self.in_range}\n" + s += f"out_range: {self.out_range}\n" + return s + + def set_null(self): + self.index = -1 + self.interval.left = 0 + self.interval.right = 0 + + def next(self): # NOQA: A003 + M = self.ts.num_edges + breakpoints = self.ts.breakpoints(as_array=True) + left_coords = self.ts.edges_left + left_order = self.ts.indexes_edge_insertion_order + right_coords = self.ts.edges_right + right_order = self.ts.indexes_edge_removal_order + + if self.index == -1: + self.interval.right = 0 + self.out_range.stop = 0 + self.in_range.stop = 0 + self.direction = FORWARD + + if self.direction == FORWARD: + left_current_index = self.in_range.stop + right_current_index = self.out_range.stop + else: + left_current_index = self.out_range.stop + 1 + right_current_index = self.in_range.stop + 1 + + left = self.interval.right + + j = right_current_index + self.out_range.start = j + while j < M and right_coords[right_order[j]] == left: + j += 1 + self.out_range.stop = j + self.out_range.order = right_order + + j = left_current_index + self.in_range.start = j + while j < M and left_coords[left_order[j]] == left: + j += 1 + self.in_range.stop = j + 
self.in_range.order = left_order + + self.direction = FORWARD + self.index += 1 + if self.index == self.ts.num_trees: + self.set_null() + else: + self.interval.left = left + self.interval.right = breakpoints[self.index + 1] + return self.index != -1 + + def prev(self): + M = self.ts.num_edges + breakpoints = self.ts.breakpoints(as_array=True) + right_coords = self.ts.edges_right + right_order = self.ts.indexes_edge_removal_order + left_coords = self.ts.edges_left + left_order = self.ts.indexes_edge_insertion_order + + if self.index == -1: + self.index = self.ts.num_trees + self.interval.left = self.ts.sequence_length + self.in_range.stop = M - 1 + self.out_range.stop = M - 1 + self.direction = REVERSE + + if self.direction == REVERSE: + left_current_index = self.out_range.stop + right_current_index = self.in_range.stop + else: + left_current_index = self.in_range.stop - 1 + right_current_index = self.out_range.stop - 1 + + right = self.interval.left + + j = left_current_index + self.out_range.start = j + while j >= 0 and left_coords[left_order[j]] == right: + j -= 1 + self.out_range.stop = j + self.out_range.order = left_order + + j = right_current_index + self.in_range.start = j + while j >= 0 and right_coords[right_order[j]] == right: + j -= 1 + self.in_range.stop = j + self.in_range.order = right_order + + self.direction = REVERSE + self.index -= 1 + if self.index == -1: + self.set_null() + else: + self.interval.left = breakpoints[self.index] + self.interval.right = right + return self.index != -1 + + def seek_forward(self, index): + # NOTE this is still in development and not fully tested. 
+ assert index >= self.index and index < self.ts.num_trees + M = self.ts.num_edges + breakpoints = self.ts.breakpoints(as_array=True) + left_coords = self.ts.edges_left + left_order = self.ts.indexes_edge_insertion_order + right_coords = self.ts.edges_right + right_order = self.ts.indexes_edge_removal_order + + if self.index == -1: + self.interval.right = 0 + self.out_range.stop = 0 + self.in_range.stop = 0 + self.direction = FORWARD + + if self.direction == FORWARD: + left_current_index = self.in_range.stop + right_current_index = self.out_range.stop + else: + left_current_index = self.out_range.stop + 1 + right_current_index = self.in_range.stop + 1 + + self.direction = FORWARD + left = breakpoints[index] + + # The range of edges we need consider for removal starts + # at the current right index and ends at the first edge + # where the right coordinate is equal to the new tree's + # left coordinate. + j = right_current_index + self.out_range.start = j + # TODO This could be done with binary search + while j < M and right_coords[right_order[j]] <= left: + j += 1 + self.out_range.stop = j + + # The range of edges we need to consider for the new tree + # must have right coordinate > left + j = left_current_index + while j < M and right_coords[left_order[j]] <= left: + j += 1 + self.in_range.start = j + # TODO this could be done with a binary search + while j < M and left_coords[left_order[j]] <= left: + j += 1 + self.in_range.stop = j + + self.interval.left = left + self.interval.right = breakpoints[index + 1] + self.out_range.order = right_order + self.in_range.order = left_order + self.index = index + + def mean_descendants(ts, reference_sets): """ Returns the mean number of nodes from the specified reference sets From 63b317cf24c0a3a1ee0b0ce358af6a2e77eead0e Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Wed, 12 Jul 2023 15:06:48 +0100 Subject: [PATCH 66/84] Convert kc_distance to use tree_position_t --- c/tskit/trees.c | 106 
++++++++++++++++++++++++------------------------ 1 file changed, 54 insertions(+), 52 deletions(-) diff --git a/c/tskit/trees.c b/c/tskit/trees.c index bce3b16b9d..8a3d0afc95 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -6104,25 +6104,29 @@ update_kc_subtree_state( } static int -update_kc_incremental(tsk_tree_t *self, kc_vectors *kc, tsk_edge_list_t *edges_out, - tsk_edge_list_t *edges_in, tsk_size_t *depths) +update_kc_incremental( + tsk_tree_t *tree, kc_vectors *kc, tsk_tree_position_t *tree_pos, tsk_size_t *depths) { int ret = 0; - tsk_edge_list_node_t *record; - tsk_edge_t *e; - tsk_id_t u; + tsk_id_t u, v, e, j; double root_time, time; - const double *times = self->tree_sequence->tables->nodes.time; + const double *restrict times = tree->tree_sequence->tables->nodes.time; + const tsk_id_t *restrict edges_child = tree->tree_sequence->tables->edges.child; + const tsk_id_t *restrict edges_parent = tree->tree_sequence->tables->edges.parent; + + tsk_bug_assert(tree_pos->index == tree->index); + tsk_bug_assert(tree_pos->interval.left == tree->interval.left); + tsk_bug_assert(tree_pos->interval.right == tree->interval.right); /* Update state of detached subtrees */ - for (record = edges_out->tail; record != NULL; record = record->prev) { - e = &record->edge; - u = e->child; + for (j = tree_pos->out.stop - 1; j >= tree_pos->out.start; j--) { + e = tree_pos->out.order[j]; + u = edges_child[e]; depths[u] = 0; - if (self->parent[u] == TSK_NULL) { - root_time = times[tsk_tree_node_root(self, u)]; - ret = update_kc_subtree_state(self, kc, u, depths, root_time); + if (tree->parent[u] == TSK_NULL) { + root_time = times[tsk_tree_node_root(tree, u)]; + ret = update_kc_subtree_state(tree, kc, u, depths, root_time); if (ret != 0) { goto out; } @@ -6130,25 +6134,25 @@ update_kc_incremental(tsk_tree_t *self, kc_vectors *kc, tsk_edge_list_t *edges_o } /* Propagate state change down into reattached subtrees. 
*/ - for (record = edges_in->tail; record != NULL; record = record->prev) { - e = &record->edge; - u = e->child; + for (j = tree_pos->in.stop - 1; j >= tree_pos->in.start; j--) { + e = tree_pos->in.order[j]; + u = edges_child[e]; + v = edges_parent[e]; - tsk_bug_assert(depths[e->child] == 0); - depths[u] = depths[e->parent] + 1; + tsk_bug_assert(depths[u] == 0); + depths[u] = depths[v] + 1; - root_time = times[tsk_tree_node_root(self, u)]; - ret = update_kc_subtree_state(self, kc, u, depths, root_time); + root_time = times[tsk_tree_node_root(tree, u)]; + ret = update_kc_subtree_state(tree, kc, u, depths, root_time); if (ret != 0) { goto out; } - if (tsk_tree_is_sample(self, u)) { - time = tsk_tree_get_branch_length_unsafe(self, u); - update_kc_vectors_single_sample(self->tree_sequence, kc, u, time); + if (tsk_tree_is_sample(tree, u)) { + time = tsk_tree_get_branch_length_unsafe(tree, u); + update_kc_vectors_single_sample(tree->tree_sequence, kc, u, time); } } - out: return ret; } @@ -6164,19 +6168,18 @@ tsk_treeseq_kc_distance(const tsk_treeseq_t *self, const tsk_treeseq_t *other, const tsk_treeseq_t *treeseqs[2] = { self, other }; tsk_tree_t trees[2]; kc_vectors kcs[2]; - tsk_diff_iter_t diff_iters[2]; - tsk_edge_list_t edges_out[2]; - tsk_edge_list_t edges_in[2]; + /* TODO the tree_pos here is redundant because we should be using this interally + * in the trees to do the advancing. 
Once we have converted the tree over to using + * tree_pos internally, we can get rid of these tree_pos variables and use + * the values stored in the trees themselves */ + tsk_tree_position_t tree_pos[2]; tsk_size_t *depths[2]; - double t0_left, t0_right, t1_left, t1_right; int ret = 0; for (i = 0; i < 2; i++) { tsk_memset(&trees[i], 0, sizeof(trees[i])); - tsk_memset(&diff_iters[i], 0, sizeof(diff_iters[i])); + tsk_memset(&tree_pos[i], 0, sizeof(tree_pos[i])); tsk_memset(&kcs[i], 0, sizeof(kcs[i])); - tsk_memset(&edges_out[i], 0, sizeof(edges_out[i])); - tsk_memset(&edges_in[i], 0, sizeof(edges_in[i])); depths[i] = NULL; } @@ -6191,7 +6194,7 @@ tsk_treeseq_kc_distance(const tsk_treeseq_t *self, const tsk_treeseq_t *other, if (ret != 0) { goto out; } - ret = tsk_diff_iter_init_from_ts(&diff_iters[i], treeseqs[i], false); + ret = tsk_tree_position_init(&tree_pos[i], treeseqs[i], 0); if (ret != 0) { goto out; } @@ -6218,11 +6221,10 @@ tsk_treeseq_kc_distance(const tsk_treeseq_t *self, const tsk_treeseq_t *other, if (ret != 0) { goto out; } - ret = tsk_diff_iter_next( - &diff_iters[0], &t0_left, &t0_right, &edges_out[0], &edges_in[0]); - tsk_bug_assert(ret == TSK_TREE_OK); - ret = update_kc_incremental( - &trees[0], &kcs[0], &edges_out[0], &edges_in[0], depths[0]); + tsk_tree_position_next(&tree_pos[0]); + tsk_bug_assert(tree_pos[0].index == 0); + + ret = update_kc_incremental(&trees[0], &kcs[0], &tree_pos[0], depths[0]); if (ret != 0) { goto out; } @@ -6231,37 +6233,37 @@ tsk_treeseq_kc_distance(const tsk_treeseq_t *self, const tsk_treeseq_t *other, if (ret != 0) { goto out; } - ret = tsk_diff_iter_next( - &diff_iters[1], &t1_left, &t1_right, &edges_out[1], &edges_in[1]); - tsk_bug_assert(ret == TSK_TREE_OK); + tsk_tree_position_next(&tree_pos[1]); + tsk_bug_assert(tree_pos[1].index != -1); - ret = update_kc_incremental( - &trees[1], &kcs[1], &edges_out[1], &edges_in[1], depths[1]); + ret = update_kc_incremental(&trees[1], &kcs[1], &tree_pos[1], depths[1]); if (ret 
!= 0) { goto out; } - while (t0_right < t1_right) { - span = t0_right - left; + tsk_bug_assert(trees[0].interval.left == tree_pos[0].interval.left); + tsk_bug_assert(trees[0].interval.right == tree_pos[0].interval.right); + tsk_bug_assert(trees[1].interval.left == tree_pos[1].interval.left); + tsk_bug_assert(trees[1].interval.right == tree_pos[1].interval.right); + while (trees[0].interval.right < trees[1].interval.right) { + span = trees[0].interval.right - left; total += norm_kc_vectors(&kcs[0], &kcs[1], lambda_) * span; - left = t0_right; + left = trees[0].interval.right; ret = tsk_tree_next(&trees[0]); tsk_bug_assert(ret == TSK_TREE_OK); ret = check_kc_distance_tree_inputs(&trees[0]); if (ret != 0) { goto out; } - ret = tsk_diff_iter_next( - &diff_iters[0], &t0_left, &t0_right, &edges_out[0], &edges_in[0]); - tsk_bug_assert(ret == TSK_TREE_OK); - ret = update_kc_incremental( - &trees[0], &kcs[0], &edges_out[0], &edges_in[0], depths[0]); + tsk_tree_position_next(&tree_pos[0]); + tsk_bug_assert(tree_pos[0].index != -1); + ret = update_kc_incremental(&trees[0], &kcs[0], &tree_pos[0], depths[0]); if (ret != 0) { goto out; } } - span = t1_right - left; - left = t1_right; + span = trees[1].interval.right - left; + left = trees[1].interval.right; total += norm_kc_vectors(&kcs[0], &kcs[1], lambda_) * span; } if (ret != 0) { @@ -6272,7 +6274,7 @@ tsk_treeseq_kc_distance(const tsk_treeseq_t *self, const tsk_treeseq_t *other, out: for (i = 0; i < 2; i++) { tsk_tree_free(&trees[i]); - tsk_diff_iter_free(&diff_iters[i]); + tsk_tree_position_free(&tree_pos[i]); kc_vectors_free(&kcs[i]); tsk_safe_free(depths[i]); } From 2c264be71ecbc2293d051d1db04bf43500b70edf Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Wed, 12 Jul 2023 16:29:13 +0100 Subject: [PATCH 67/84] Convert LS HMM code to use tree_position_t --- c/tskit/haplotype_matching.c | 63 ++++++++++++++++++------------------ c/tskit/haplotype_matching.h | 7 ++-- 2 files changed, 36 insertions(+), 34 deletions(-) diff 
--git a/c/tskit/haplotype_matching.c b/c/tskit/haplotype_matching.c index b942da18d6..d6fdfd7f46 100644 --- a/c/tskit/haplotype_matching.c +++ b/c/tskit/haplotype_matching.c @@ -209,7 +209,7 @@ int tsk_ls_hmm_free(tsk_ls_hmm_t *self) { tsk_tree_free(&self->tree); - tsk_diff_iter_free(&self->diffs); + tsk_tree_position_free(&self->tree_pos); tsk_safe_free(self->recombination_rate); tsk_safe_free(self->mutation_rate); tsk_safe_free(self->recombination_rate); @@ -248,9 +248,8 @@ tsk_ls_hmm_reset(tsk_ls_hmm_t *self) tsk_memset(self->transition_parent, 0xff, self->max_transitions * sizeof(*self->transition_parent)); - /* This is safe because we've already zero'd out the memory. */ - tsk_diff_iter_free(&self->diffs); - ret = tsk_diff_iter_init_from_ts(&self->diffs, self->tree_sequence, false); + tsk_tree_position_free(&self->tree_pos); + ret = tsk_tree_position_init(&self->tree_pos, self->tree_sequence, 0); if (ret != 0) { goto out; } @@ -306,21 +305,20 @@ tsk_ls_hmm_update_tree(tsk_ls_hmm_t *self) int ret = 0; tsk_id_t *restrict parent = self->parent; tsk_id_t *restrict T_index = self->transition_index; + const tsk_id_t *restrict edges_child = self->tree_sequence->tables->edges.child; + const tsk_id_t *restrict edges_parent = self->tree_sequence->tables->edges.parent; tsk_value_transition_t *restrict T = self->transitions; - tsk_edge_list_node_t *record; - tsk_edge_list_t records_out, records_in; - tsk_edge_t edge; - double left, right; - tsk_id_t u; + tsk_id_t u, c, p, j, e; tsk_value_transition_t *vt; - ret = tsk_diff_iter_next(&self->diffs, &left, &right, &records_out, &records_in); - if (ret < 0) { - goto out; - } + tsk_tree_position_next(&self->tree_pos); + tsk_bug_assert(self->tree_pos.index != -1); + tsk_bug_assert(self->tree_pos.index == self->tree.index); - for (record = records_out.head; record != NULL; record = record->next) { - u = record->edge.child; + for (j = self->tree_pos.out.start; j < self->tree_pos.out.stop; j++) { + e = self->tree_pos.out.order[j]; 
+ c = edges_child[e]; + u = c; if (T_index[u] == TSK_NULL) { /* Ensure the subtree we're detaching has a transition at the root */ while (T_index[u] == TSK_NULL) { @@ -328,25 +326,27 @@ tsk_ls_hmm_update_tree(tsk_ls_hmm_t *self) tsk_bug_assert(u != TSK_NULL); } tsk_bug_assert(self->num_transitions < self->max_transitions); - T_index[record->edge.child] = (tsk_id_t) self->num_transitions; - T[self->num_transitions].tree_node = record->edge.child; + T_index[c] = (tsk_id_t) self->num_transitions; + T[self->num_transitions].tree_node = c; T[self->num_transitions].value = T[T_index[u]].value; self->num_transitions++; } - parent[record->edge.child] = TSK_NULL; + parent[c] = TSK_NULL; } - for (record = records_in.head; record != NULL; record = record->next) { - edge = record->edge; - parent[edge.child] = edge.parent; - u = edge.parent; - if (parent[edge.parent] == TSK_NULL) { + for (j = self->tree_pos.in.start; j < self->tree_pos.in.stop; j++) { + e = self->tree_pos.in.order[j]; + c = edges_child[e]; + p = edges_parent[e]; + parent[c] = p; + u = p; + if (parent[p] == TSK_NULL) { /* Grafting onto a new root. 
*/ - if (T_index[record->edge.parent] == TSK_NULL) { - T_index[edge.parent] = (tsk_id_t) self->num_transitions; + if (T_index[p] == TSK_NULL) { + T_index[p] = (tsk_id_t) self->num_transitions; tsk_bug_assert(self->num_transitions < self->max_transitions); - T[self->num_transitions].tree_node = edge.parent; - T[self->num_transitions].value = T[T_index[edge.child]].value; + T[self->num_transitions].tree_node = p; + T[self->num_transitions].value = T[T_index[c]].value; self->num_transitions++; } } else { @@ -356,18 +356,17 @@ tsk_ls_hmm_update_tree(tsk_ls_hmm_t *self) } tsk_bug_assert(u != TSK_NULL); } - tsk_bug_assert(T_index[u] != -1 && T_index[edge.child] != -1); - if (T[T_index[u]].value == T[T_index[edge.child]].value) { - vt = &T[T_index[edge.child]]; + tsk_bug_assert(T_index[u] != -1 && T_index[c] != -1); + if (T[T_index[u]].value == T[T_index[c]].value) { + vt = &T[T_index[c]]; /* Mark the value transition as unusued */ vt->value = -1; vt->tree_node = TSK_NULL; - T_index[edge.child] = TSK_NULL; + T_index[c] = TSK_NULL; } } ret = tsk_ls_hmm_remove_dead_roots(self); -out: return ret; } diff --git a/c/tskit/haplotype_matching.h b/c/tskit/haplotype_matching.h index 46631fb086..e4d82bef03 100644 --- a/c/tskit/haplotype_matching.h +++ b/c/tskit/haplotype_matching.h @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -98,7 +98,10 @@ typedef struct _tsk_ls_hmm_t { tsk_size_t num_nodes; /* state */ tsk_tree_t tree; - tsk_diff_iter_t diffs; + /* NOTE: this tree_position will be redundant once we integrate the top-level + * tree class with this. 
+ */ + tsk_tree_position_t tree_pos; tsk_id_t *parent; /* The probability value transitions on the tree */ tsk_value_transition_t *transitions; From 1ebf6190633ee3794117a409fedfb560beb2a8e2 Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Sat, 29 Jul 2023 00:33:36 +0100 Subject: [PATCH 68/84] Correct doc instances of "raises" (#2807) As per sphinx docs --- python/tskit/intervals.py | 2 +- python/tskit/provenance.py | 4 ++-- python/tskit/trees.py | 17 ++++++++--------- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/python/tskit/intervals.py b/python/tskit/intervals.py index ef371531e0..0c78c50b5b 100644 --- a/python/tskit/intervals.py +++ b/python/tskit/intervals.py @@ -265,7 +265,7 @@ def find_index(self, x: float) -> int: :param float x: The position to search. :return: The index of the interval containing this point. :rtype: int - :raises: KeyError if the position is not contained in any of the intervals. + :raises KeyError: if the position is not contained in any of the intervals. """ if x < 0 or x >= self.sequence_length: raise KeyError(f"Position {x} out of bounds") diff --git a/python/tskit/provenance.py b/python/tskit/provenance.py index 82fb19518a..bc88e29f1a 100644 --- a/python/tskit/provenance.py +++ b/python/tskit/provenance.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2020 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # Copyright (c) 2016-2017 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -117,7 +117,7 @@ def validate_provenance(provenance): :param dict provenance: The dictionary representing a JSON document to be validated against the schema. - :raises: :class:`tskit.ProvenanceValidationError` + :raises ProvenanceValidationError: if the schema is not valid. 
""" schema = get_schema() try: diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 6c2fbb05a5..177f357306 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -878,7 +878,7 @@ def unrank(num_leaves, rank, *, span=1, branch_length=1) -> Tree: from which the tree is taken will have its :attr:`~tskit.TreeSequence.sequence_length` equal to ``span``. :param: float branch_length: The minimum length of a branch in this tree. - :raises: ValueError: If the given rank is out of bounds for trees + :raises ValueError: If the given rank is out of bounds for trees with ``num_leaves`` leaves. """ rank_tree = combinatorics.RankTree.unrank(num_leaves, rank) @@ -1600,7 +1600,7 @@ def root(self): :return: The root node. :rtype: int - :raises: :class:`ValueError` if this tree contains more than one root. + :raises ValueError: if this tree contains more than one root. """ if self.has_multiple_roots: raise ValueError("More than one root exists. Use tree.roots instead") @@ -5211,10 +5211,10 @@ def haplotypes( *Deprecated in 0.3.0. Use ``isolated_as_missing``, but inverting value. Will be removed in a future version* :rtype: collections.abc.Iterable - :raises: TypeError if the ``missing_data_character`` or any of the alleles + :raises TypeError: if the ``missing_data_character`` or any of the alleles at a site are not a single ascii character. - :raises: ValueError - if the ``missing_data_character`` exists in one of the alleles + :raises ValueError: if the ``missing_data_character`` exists in one of the + alleles """ if impute_missing_data is not None: warnings.warn( @@ -5521,10 +5521,9 @@ def alignments( :return: An iterator over the alignment strings for specified samples in this tree sequence, in the order given in ``samples``. :rtype: collections.abc.Iterable - :raises: ValueError - if any genome coordinate in this tree sequence is not discrete, - or if the ``reference_sequence`` is not of the correct length. 
- :raises: TypeError if any of the alleles at a site are not a + :raises ValueError: if any genome coordinate in this tree sequence is not + discrete, or if the ``reference_sequence`` is not of the correct length. + :raises TypeError: if any of the alleles at a site are not a single ascii character. """ if not self.discrete_genome: From 3340c62ef059549444555737a47d462cc9a24318 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Thu, 13 Jul 2023 14:24:11 +0100 Subject: [PATCH 69/84] Update tskit-book-theme --- .github/workflows/docs.yml | 2 +- docs/_config.yml | 3 + python/requirements/CI-docs/requirements.txt | 157 +------------------ 3 files changed, 10 insertions(+), 152 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 668470e1f6..d6c90ff434 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -48,7 +48,7 @@ jobs: python -m venv venv . venv/bin/activate pip install --upgrade pip wheel - cat ${{env.REQUIREMENTS}} | sed -e '/^\s*#.*$/d' -e '/^\s*$/d' | xargs -n 1 pip install --no-dependencies + pip install -r ${{env.REQUIREMENTS}} - name: Build C module if: env.MAKE_TARGET diff --git a/docs/_config.yml b/docs/_config.yml index e9ced63c29..c781ee3529 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -43,6 +43,9 @@ sphinx: config: html_theme: tskit_book_theme + html_theme_options: + pygment_light_style: monokai + pygment_dark_style: monokai pygments_style: monokai myst_enable_extensions: - colon_fence diff --git a/python/requirements/CI-docs/requirements.txt b/python/requirements/CI-docs/requirements.txt index 6569b3775a..e9733957ae 100644 --- a/python/requirements/CI-docs/requirements.txt +++ b/python/requirements/CI-docs/requirements.txt @@ -1,154 +1,9 @@ -# Due to issues with indirect dependencies introducing conflicting dependencies -# we freeze everything to get a reproducible build. 
-alabaster==0.7.13 -anyio==3.6.2 -argon2-cffi==21.3.0 -argon2-cffi-bindings==21.2.0 -arrow==1.2.3 -asttokens==2.2.1 -attrs==21.4.0 -Babel==2.11.0 -backcall==0.2.0 -beautifulsoup4==4.11.1 -bleach==5.0.1 -breathe==4.34.0 -certifi==2022.12.7 -cffi==1.15.1 -charset-normalizer==3.0.1 -click==8.1.3 -colorama==0.4.6 -comm==0.1.2 -debugpy==1.6.5 -decorator==5.1.1 -defusedxml==0.7.1 -demes==0.2.2 -Deprecated==1.2.13 -docutils==0.17.1 -entrypoints==0.4 -executing==1.2.0 -fastjsonschema==2.16.2 -fqdn==1.5.1 -gitdb==4.0.10 -GitPython==3.1.30 -greenlet==2.0.1 -idna==3.4 -imagesize==1.4.1 -importlib-metadata==6.0.0 -ipykernel==6.20.2 -ipython==8.8.0 -ipython-genutils==0.2.0 -ipywidgets==7.7.2 -isoduration==20.11.0 -jedi==0.18.2 -Jinja2==3.1.2 -jsonpointer==2.3 -jsonschema==4.17.3 -jupyter-book==0.13.1 -jupyter-cache==0.4.3 -jupyter-events==0.6.3 -jupyter-server-mathjax==0.2.6 -jupyter-sphinx==0.3.2 -jupyter_client==7.4.9 -jupyter_core==5.1.3 -jupyter_server==2.1.0 -jupyter_server_terminals==0.4.4 -jupyterlab-pygments==0.2.2 -jupyterlab-widgets==1.1.1 -latexcodec==2.0.1 -linkify-it-py==1.0.3 -lxml==4.9.2 -markdown-it-py==1.1.0 -MarkupSafe==2.1.2 -matplotlib-inline==0.1.6 -mdit-py-plugins==0.2.8 -mistune==0.8.4 -msprime==1.2.0 -myst-nb==0.13.2 -myst-parser==0.15.2 -nbclassic==0.4.8 -nbclient==0.5.13 -nbconvert==6.5.4 -nbdime==3.1.1 -nbformat==5.7.3 -nest-asyncio==1.5.6 -newick==1.6.0 -notebook==6.5.2 -notebook_shim==0.2.2 -numpy==1.24.1 -packaging==23.0 -pandocfilters==1.5.0 -parso==0.8.3 -pbr==5.11.1 -pexpect==4.8.0 -pickleshare==0.7.5 -platformdirs==2.6.2 -prometheus-client==0.15.0 -prompt-toolkit==3.0.36 -psutil==5.9.4 -ptyprocess==0.7.0 -pure-eval==0.2.2 -pybtex==0.24.0 -pybtex-docutils==1.0.2 -pycparser==2.21 -pydata-sphinx-theme==0.8.1 -PyGithub==1.57 -Pygments==2.14.0 -PyJWT==2.6.0 -PyNaCl==1.5.0 -pyrsistent==0.19.3 -python-dateutil==2.8.2 -python-json-logger==2.0.4 -pytz==2022.7.1 -PyYAML==6.0 -pyzmq==25.0.0 -requests==2.28.2 -rfc3339-validator==0.1.4 
-rfc3986-validator==0.1.1 -ruamel.yaml==0.17.21 -ruamel.yaml.clib==0.2.7 -Send2Trash==1.8.0 -six==1.16.0 -smmap==5.0.0 -sniffio==1.3.0 -snowballstemmer==2.2.0 -soupsieve==2.3.2.post1 -Sphinx==4.5.0 -sphinx-argparse==0.4.0 +jupyter-book==0.15.1 +breathe==4.35.0 sphinx-autodoc-typehints==1.19.1 -sphinx-book-theme==0.3.3 -sphinx-comments==0.0.3 -sphinx-copybutton==0.5.1 -sphinx-external-toc==0.2.4 sphinx-issues==3.0.1 -sphinx-jupyterbook-latex==0.4.7 -sphinx-multitoc-numbering==0.1.3 -sphinx-thebe==0.1.2 -sphinx-togglebutton==0.3.2 -sphinx_design==0.1.0 -sphinxcontrib-bibtex==2.5.0 -sphinxcontrib-devhelp==1.0.2 -sphinxcontrib-htmlhelp==2.0.0 -sphinxcontrib-jsmath==1.0.1 -sphinxcontrib-prettyspecialmethods==0.1.0 -sphinxcontrib-qthelp==1.0.3 -sphinxcontrib-serializinghtml==1.1.5 -sphinxcontrib.applehelp==1.0.3 -SQLAlchemy==1.4.46 -stack-data==0.6.2 +sphinx-argparse==0.4.0 +numpy==1.25.1 svgwrite==1.4.3 -terminado==0.17.1 -tinycss2==1.2.1 -tornado==6.2 -traitlets==5.8.1 -tskit==0.5.4 -tskit-book-theme==0.3.2 -uc-micro-py==1.0.1 -uri-template==1.2.0 -urllib3==1.26.14 -wcwidth==0.2.6 -webcolors==1.12 -webencodings==0.5.1 -websocket-client==1.4.2 -widgetsnbextension==3.6.1 -wrapt==1.14.1 -zipp==3.11.0 +msprime==1.2.0 +tskit-book-theme \ No newline at end of file From ba202827a85c62a4283df12a4ffc6edbc5a5da38 Mon Sep 17 00:00:00 2001 From: peter Date: Mon, 28 Nov 2022 13:09:48 -0800 Subject: [PATCH 70/84] first draft passes tests with C code, backwards even! 
--- c/tests/test_tables.c | 36 +++++ c/tskit/tables.c | 294 ++++++++++++++++++++++++++++++++-- c/tskit/tables.h | 24 +++ c/tskit/trees.h | 4 +- python/.gitignore | 1 + python/_tskitmodule.c | 31 ++++ python/tests/test_topology.py | 267 ++++++++++++++++++++++++++++++ python/tskit/tables.py | 6 + python/tskit/trees.py | 21 +++ 9 files changed, 667 insertions(+), 17 deletions(-) diff --git a/c/tests/test_tables.c b/c/tests/test_tables.c index 6de6675ff6..c6c6ac0053 100644 --- a/c/tests/test_tables.c +++ b/c/tests/test_tables.c @@ -2507,11 +2507,46 @@ test_edge_table_copy_semantics(void) tsk_treeseq_free(&ts); } +static void +test_extend_edges(void) +{ + int ret; + tsk_table_collection_t tables, tables_copy; + + const char *nodes_ex = "1 0 -1 -1\n" + "1 0 -1 -1\n" + "0 2.0 -1 -1\n"; + const char *edges_ex = "0 10 2 0\n" + "0 10 2 1\n"; + + ret = tsk_table_collection_init(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tables.sequence_length = 10; + + parse_nodes(nodes_ex, &tables.nodes); + CU_ASSERT_EQUAL_FATAL(tables.nodes.num_rows, 3); + parse_edges(edges_ex, &tables.edges); + CU_ASSERT_EQUAL_FATAL(tables.edges.num_rows, 2); + tsk_table_collection_build_index(&tables, 0); + tsk_table_collection_copy(&tables, &tables_copy, 0); + + ret = tsk_table_collection_extend_edges(&tables, 10, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_table_collection_equals(&tables, &tables_copy, 0)); + + // Free things. 
+ tsk_table_collection_free(&tables); + tsk_table_collection_free(&tables_copy); +} + static void test_edge_table_squash(void) { int ret; tsk_table_collection_t tables; + ret = tsk_table_collection_init(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tables.sequence_length = 10; const char *nodes_ex = "1 0 -1 -1\n" "1 0 -1 -1\n" @@ -11606,6 +11641,7 @@ main(int argc, char **argv) { "test_simplify_tables_drops_indexes", test_simplify_tables_drops_indexes }, { "test_simplify_empty_tables", test_simplify_empty_tables }, { "test_simplify_metadata", test_simplify_metadata }, + { "test_extend_edges", test_extend_edges }, { "test_link_ancestors_no_edges", test_link_ancestors_no_edges }, { "test_link_ancestors_input_errors", test_link_ancestors_input_errors }, { "test_link_ancestors_single_tree", test_link_ancestors_single_tree }, diff --git a/c/tskit/tables.c b/c/tskit/tables.c index 8eea85f5ad..97ab0137ac 100644 --- a/c/tskit/tables.c +++ b/c/tskit/tables.c @@ -8321,8 +8321,8 @@ pair_to_integer(tsk_id_t a, tsk_id_t b, tsk_size_t N) static inline void integer_to_pair(int64_t index, tsk_size_t N, tsk_id_t *a, tsk_id_t *b) { - *a = (tsk_id_t)(index / (int64_t) N); - *b = (tsk_id_t)(index % (int64_t) N); + *a = (tsk_id_t) (index / (int64_t) N); + *b = (tsk_id_t) (index % (int64_t) N); } static int64_t @@ -10806,7 +10806,7 @@ tsk_table_collection_check_individual_integrity( /* Check parent references are valid */ if (individuals.parents[k] != TSK_NULL && (individuals.parents[k] < 0 - || individuals.parents[k] >= num_individuals)) { + || individuals.parents[k] >= num_individuals)) { ret = TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS; goto out; } @@ -11160,20 +11160,20 @@ tsk_table_collection_equals(const tsk_table_collection_t *self, if (!(options & TSK_CMP_IGNORE_TABLES)) { ret = ret && tsk_individual_table_equals( - &self->individuals, &other->individuals, options) + &self->individuals, &other->individuals, options) && tsk_node_table_equals(&self->nodes, &other->nodes, options) && 
tsk_edge_table_equals(&self->edges, &other->edges, options) && tsk_migration_table_equals( - &self->migrations, &other->migrations, options) + &self->migrations, &other->migrations, options) && tsk_site_table_equals(&self->sites, &other->sites, options) && tsk_mutation_table_equals(&self->mutations, &other->mutations, options) && tsk_population_table_equals( - &self->populations, &other->populations, options); + &self->populations, &other->populations, options); /* TSK_CMP_IGNORE_TABLES implies TSK_CMP_IGNORE_PROVENANCE */ if (!(options & TSK_CMP_IGNORE_PROVENANCE)) { ret = ret && tsk_provenance_table_equals( - &self->provenances, &other->provenances, options); + &self->provenances, &other->provenances, options); } } /* TSK_CMP_IGNORE_TS_METADATA is implied by TSK_CMP_IGNORE_METADATA */ @@ -11183,19 +11183,19 @@ tsk_table_collection_equals(const tsk_table_collection_t *self, if (!(options & TSK_CMP_IGNORE_TS_METADATA)) { ret = ret && (self->metadata_length == other->metadata_length - && self->metadata_schema_length == other->metadata_schema_length - && tsk_memcmp(self->metadata, other->metadata, - self->metadata_length * sizeof(char)) - == 0 - && tsk_memcmp(self->metadata_schema, other->metadata_schema, - self->metadata_schema_length * sizeof(char)) - == 0); + && self->metadata_schema_length == other->metadata_schema_length + && tsk_memcmp(self->metadata, other->metadata, + self->metadata_length * sizeof(char)) + == 0 + && tsk_memcmp(self->metadata_schema, other->metadata_schema, + self->metadata_schema_length * sizeof(char)) + == 0); } if (!(options & TSK_CMP_IGNORE_REFERENCE_SEQUENCE)) { ret = ret && tsk_reference_sequence_equals( - &self->reference_sequence, &other->reference_sequence, options); + &self->reference_sequence, &other->reference_sequence, options); } return ret; } @@ -12885,6 +12885,270 @@ tsk_table_collection_add_and_remap_node(tsk_table_collection_t *self, return ret; } +typedef struct _edge_list_t { + tsk_id_t edge; + bool extended; // have we 
decided to extend this one on the current tree? + struct _edge_list_t *next; +} edge_list_t; + +static edge_list_t *TSK_WARN_UNUSED +extend_edges_alloc_entry(tsk_blkalloc_t *heap, tsk_id_t edge) +{ + // see ancestor_mapper_alloc_interval_list + edge_list_t *x = NULL; + + x = tsk_blkalloc_get(heap, sizeof(*x)); + if (x == NULL) { + goto out; + } + tsk_bug_assert(edge >= 0); + + x->edge = edge; + x->extended = false; + x->next = NULL; +out: + return x; +} + +static void +remove_unextended(edge_list_t **head, edge_list_t **tail) +{ + edge_list_t *px, *x; + + px = *head; + while (px != NULL && !px->extended) { + px = px->next; + } + *head = px; + if (px != NULL) { + px->extended = false; + x = px->next; + while (x != NULL) { + if (x->extended) { + // keep it + x->extended = false; + px->next = x; + px = x; + } + x = x->next; + } + } + *tail = px; +} + +static void +reverse_array(int *arr, int *dest, tsk_size_t n) +{ + for (tsk_size_t i = 0; i < n; i++) { + dest[i] = arr[n - i - 1]; + } +} + +static int +forward_extend(tsk_table_collection_t *self, int direction) +{ + int ret = 0; + double *new_left, *new_right; + tsk_id_t *num_edges; + tsk_id_t tj, tk, ret_id; + tsk_id_t e1, e2, e_in; + tsk_id_t *I, *O; + const tsk_id_t M = (tsk_id_t) self->edges.num_rows; + double left, right; + double *near_edge, *far_edge; + tsk_blkalloc_t edge_list_heap; + edge_list_t *edges_in_head, *edges_in_tail; + edge_list_t *edges_out_head, *edges_out_tail; + edge_list_t *x, *y, *ex1, *ex2, *ex_in; + tsk_edge_table_t edges; + tsk_edge_t edge; + double sign, here, there; + bool forwards; + + forwards = (direction == TSK_DIR_FORWARD); + + num_edges = tsk_malloc(self->nodes.num_rows * sizeof(*num_edges)); + new_left = tsk_malloc(self->edges.num_rows * sizeof(*new_left)); + new_right = tsk_malloc(self->edges.num_rows * sizeof(*new_right)); + if (num_edges == NULL || new_left == NULL || new_right == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + tsk_memset(num_edges, 0x00, 
self->nodes.num_rows * sizeof(*num_edges)); + memcpy(new_left, self->edges.left, self->edges.num_rows * sizeof(*new_left)); + memcpy(new_right, self->edges.right, self->edges.num_rows * sizeof(*new_right)); + + ret = tsk_blkalloc_init(&edge_list_heap, 8192); + if (ret != 0) { + goto out; + } + if (forwards) { + I = self->indexes.edge_insertion_order; + O = self->indexes.edge_removal_order; + here = 0; + sign = 1; + near_edge = self->edges.left; + far_edge = self->edges.right; + } else { + I = tsk_malloc(self->edges.num_rows * sizeof(*I)); + O = tsk_malloc(self->edges.num_rows * sizeof(*O)); + reverse_array(self->indexes.edge_removal_order, I, self->edges.num_rows); + reverse_array(self->indexes.edge_insertion_order, O, self->edges.num_rows); + here = self->sequence_length; + sign = -1; + near_edge = self->edges.right; + far_edge = self->edges.left; + } + tj = 0; // current position in I + tk = 0; // current position in O + left = 0; + edges_in_head = NULL; + edges_in_tail = NULL; + edges_out_head = NULL; + edges_out_tail = NULL; + while (tj < M) { + // remove entries that aren't being extended/postponed + remove_unextended(&edges_in_head, &edges_in_tail); + remove_unextended(&edges_out_head, &edges_out_tail); + + while ((tk < M) && (far_edge[O[tk]] == here)) { + // add edge tk to pending_out + x = extend_edges_alloc_entry(&edge_list_heap, O[tk]); + if (edges_out_tail == NULL) { + edges_out_head = x; + } else { + y = edges_out_tail; + y->next = x; + } + edges_out_tail = x; + num_edges[self->edges.parent[O[tk]]] -= 1; + num_edges[self->edges.child[O[tk]]] -= 1; + tk++; + } + while ((tj < M) && (near_edge[I[tj]] == here)) { + // add edge tj to pending_in + x = extend_edges_alloc_entry(&edge_list_heap, I[tj]); + if (edges_in_tail == NULL) { + edges_in_head = x; + } else { + y = edges_in_tail; + y->next = x; + } + edges_in_tail = x; + num_edges[self->edges.parent[I[tj]]] += 1; + num_edges[self->edges.child[I[tj]]] += 1; + tj++; + } + there = forwards ? 
self->sequence_length : 0; + if (forwards) { + if (tk < M) { + there = TSK_MIN(there, far_edge[O[tk]]); + } + if (tj < M) { + there = TSK_MIN(there, near_edge[I[tj]]); + } + } else { + if (tk < M) { + there = TSK_MAX(there, far_edge[O[tk]]); + } + if (tj < M) { + there = TSK_MAX(there, near_edge[I[tj]]); + } + } + + // iterate over pairs of out and in: (ex1, ex2, in) + for (ex1 = edges_out_head; ex1 != NULL; ex1 = ex1->next) { + if (!ex1->extended) { + e1 = ex1->edge; + for (ex2 = edges_out_head; ex2 != NULL; ex2 = ex2->next) { + if (!ex2->extended) { + e2 = ex2->edge; + if ((self->edges.parent[e1] == self->edges.child[e2]) + && (num_edges[self->edges.child[e2]] == 0)) { + for (ex_in = edges_in_head; ex_in != NULL; + ex_in = ex_in->next) { + e_in = ex_in->edge; + if (sign * far_edge[e_in] > sign * here) { + if ((self->edges.child[e1] + == self->edges.child[e_in]) + && (self->edges.parent[e2] + == self->edges.parent[e_in])) { + ex1->extended = true; + ex2->extended = true; + ex_in->extended = true; + if (forwards) { + new_right[e1] = there; + new_right[e2] = there; + new_left[e_in] = there; + } else { + new_left[e1] = there; + new_left[e2] = there; + new_right[e_in] = there; + } + num_edges[self->edges.parent[e1]] += 2; + } + } + } + } + } + } + } + } + // cleanup at end of loop + here = there; + } + + // done! 
write out new edge tables + ret = tsk_edge_table_copy(&self->edges, &edges, 0); + if (ret != 0) { + goto out; + } + ret = tsk_edge_table_clear(&self->edges); + if (ret != 0) { + goto out; + } + for (tj = 0; tj < (tsk_id_t) edges.num_rows; tj++) { + left = new_left[tj]; + right = new_right[tj]; + if (left < right) { + tsk_edge_table_get_row_unsafe(&edges, tj, &edge); + ret_id = tsk_edge_table_add_row(&self->edges, left, right, edge.parent, + edge.child, edge.metadata, edge.metadata_length); + if (ret_id < 0) { + ret = (int) ret_id; + goto out; + } + } + } + ret = tsk_table_collection_build_index(self, 0); + if (ret != 0) { + goto out; + } + +out: + tsk_blkalloc_free(&edge_list_heap); + return ret; +} + +int TSK_WARN_UNUSED +tsk_table_collection_extend_edges( + tsk_table_collection_t *self, int max_iter, tsk_flags_t TSK_UNUSED(options)) +{ + int ret = 0; + tsk_size_t last_num_edges; + last_num_edges = self->edges.num_rows; + for (int j = 0; j <= max_iter; j++) { + forward_extend(self, TSK_DIR_FORWARD); + forward_extend(self, TSK_DIR_REVERSE); + if (self->edges.num_rows == last_num_edges) { + break; + } else { + last_num_edges = self->edges.num_rows; + } + } + return ret; +} + int TSK_WARN_UNUSED tsk_table_collection_subset(tsk_table_collection_t *self, const tsk_id_t *nodes, tsk_size_t num_nodes, tsk_flags_t options) diff --git a/c/tskit/tables.h b/c/tskit/tables.h index 38f3096c9d..872b9b8fa1 100644 --- a/c/tskit/tables.h +++ b/c/tskit/tables.h @@ -42,6 +42,14 @@ extern "C" { #include +/****************************************************************************/ +/* Generic definitions */ +/****************************************************************************/ + +// These are also used in trees.h +#define TSK_DIR_FORWARD 1 +#define TSK_DIR_REVERSE -1 + /****************************************************************************/ /* Definitions for the basic objects */ /****************************************************************************/ @@ -4351,6 
+4359,22 @@ Options can be specified by providing one or more of the following bitwise int tsk_table_collection_simplify(tsk_table_collection_t *self, const tsk_id_t *samples, tsk_size_t num_samples, tsk_flags_t options, tsk_id_t *node_map); +/** +@brief Extends edges + +TODO DOCUMENT + +@rst + +**Options**: None currently defined. +@endrst + +@param self A pointer to a tsk_table_collection_t object. +@param options Bitwise option flags. (UNUSED) +@return Return 0 on success or a negative value on failure. +*/ +int tsk_table_collection_extend_edges(tsk_table_collection_t *self, int max_iter, tsk_flags_t options); + /** @brief Subsets and reorders a table collection according to an array of nodes. diff --git a/c/tskit/trees.h b/c/tskit/trees.h index 95c66a6ac7..e512f8d67a 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -56,8 +56,8 @@ extern "C" { /* Options for map_mutations */ #define TSK_MM_FIXED_ANCESTRAL_STATE (1 << 0) -#define TSK_DIR_FORWARD 1 -#define TSK_DIR_REVERSE -1 +/* For the edge diff iterator */ +#define TSK_INCLUDE_TERMINAL (1 << 0) /** @defgroup API_FLAGS_TS_INIT_GROUP :c:func:`tsk_treeseq_init` specific flags. 
diff --git a/python/.gitignore b/python/.gitignore index 1d3e405ad1..acdfa61981 100644 --- a/python/.gitignore +++ b/python/.gitignore @@ -3,3 +3,4 @@ *.egg-info build .*.swp +*/.ipynb_checkpoints diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 5c6bd29986..5a865535ba 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -6989,6 +6989,33 @@ TableCollection_link_ancestors(TableCollection *self, PyObject *args, PyObject * return ret; } +static PyObject * +TableCollection_extend_edges(TableCollection *self, PyObject *args, PyObject *kwds) +{ + int err; + PyObject *ret = NULL; + int max_iter; + tsk_flags_t options = 0; + static char *kwlist[] = { "max_iter", NULL }; + + if (TableCollection_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords(args, kwds, "i", kwlist, &max_iter)) { + goto out; + } + + err = tsk_table_collection_extend_edges(self->tables, max_iter, options); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + + static PyObject * TableCollection_subset(TableCollection *self, PyObject *args, PyObject *kwds) { @@ -7790,6 +7817,10 @@ static PyMethodDef TableCollection_methods[] = { .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Returns an edge table linking samples to a set of specified ancestors." }, + { .ml_name = "extend_edges", + .ml_meth = (PyCFunction) TableCollection_extend_edges, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Extends edges TODO DOCUMENT." 
}, { .ml_name = "subset", .ml_meth = (PyCFunction) TableCollection_subset, .ml_flags = METH_VARARGS | METH_KEYWORDS, diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index d564ec0590..3808d12390 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -8515,3 +8515,270 @@ def test_is_isolated_bad(self): tree.is_isolated("abc") with pytest.raises(TypeError): tree.is_isolated(1.1) + + +class TestExtendEdges: + """ + Test the 'extend edges' method + """ + + def py_extend_edges(): + ts = self + last_num_edges = ts.num_edges + for _ in range(max_iter): + ts = ts.forward_extend(forwards=True) + ts = ts.forward_extend(forwards=False) + if ts.num_edges == last_num_edges: + break + else: + last_num_edges = ts.num_edges + return ts + + def _extend(self, forwards=True): + print("forwards:", forwards) + num_edges = np.full(self.num_nodes, 0) + + t = self.tables + edges = t.edges.copy() + t.edges.clear() + new_left = edges.left + new_right = edges.right + + # edge diff stuff + M = edges.num_rows + if forwards: + I = self.indexes_edge_insertion_order + O = self.indexes_edge_removal_order + # "here" will be left if fowards else right + here = 0 + endpoint = self.sequence_length + sign = +1 + near_edge = edges.left + far_edge = edges.right + else: + I = np.flip(self.indexes_edge_removal_order) + O = np.flip(self.indexes_edge_insertion_order) + here = self.sequence_length + endpoint = 0 + sign = -1 + near_edge = edges.right + far_edge = edges.left + tj = 0 + tk = 0 + edges_out = [] + edges_in = [] + + while (tj < M): + # clear out non-extended or postponed edges + edges_out = [[e, False] for e, x in edges_out if x] + edges_in = [[e, False] for e, x in edges_in if x] + + # Find edges_out between trees + while (tk < M) and (far_edge[O[tk]] == here): + edges_out.append([O[tk], False]) + num_edges[edges.parent[O[tk]]] -= 1 + num_edges[edges.child[O[tk]]] -= 1 + #print("Edge Out", tk, edges[O[tk]]) + tk += 1 + # Find edges_in between 
trees + while (tj < M) and (near_edge[I[tj]] == here): + edges_in.append([I[tj], False]) + num_edges[edges.parent[I[tj]]] += 1 + num_edges[edges.child[I[tj]]] += 1 + #print("Edge In", tj, edges[I[tj]]) + tj += 1 + + # Find smallest length right endpoint of all edges in edges_in and edges_out + # there should equal the endpoint of a T_k + there = self.sequence_length if forwards else 0 + if forwards: + if tk < M: + there = min(there, far_edge[O[tk]]) + if tj < M: + there = min(there, near_edge[I[tj]]) + else: + if tk < M: + there = max(there, far_edge[O[tk]]) + if tj < M: + there = max(there, near_edge[I[tj]]) + print("All Edges Out", edges_out) + print("All Edges In", edges_in) + assert np.all(num_edges >= 0) + print("-------------", here, len(edges_out), len(edges_in)) + for ex1 in edges_out: + #print("e1:", e1, [edges.parent[O[e1]], edges.child[O[e1]]], edges[O[e1]]) + if not ex1[1]: + e1 = ex1[0] + for ex2 in edges_out: + #print("e2:", e2, num_edges[edges.child[e2]], ":", [edges.parent[O[e2]], edges.child[O[e2]]]) + if not ex2[1]: + # need the intermediate node to not be present in + # the new tree + e2 = ex2[0] + if ((edges.parent[e1] == edges.child[e2]) + and (num_edges[edges.child[e2]] == 0)): + for ex_in in edges_in: + e_in = ex_in[0] + #print("ein", e_in, [edges.parent[I[e_in]], edges.child[I[e_in]]]) + if sign * far_edge[e_in] > sign * here: + if ( + edges.child[e1] == edges.child[e_in] + and edges.parent[e2] == edges.parent[e_in] + ): + print("EXTEND") + # extend e2->e1 and postpone e_in + ex1[1] = True + ex2[1] = True + ex_in[1] = True + if forwards: + new_right[e1] = there + new_right[e2] = there + new_left[e_in] = there + else: + new_left[e1] = there + new_left[e2] = there + new_right[e_in] = there + # amend num_edges: the intermediate + # node has 2 edges instead of 0 + num_edges[edges.parent[e1]] += 2 + # cleanup at end of loop + here = there + + for j in range(edges.num_rows): + left = new_left[j] + right = new_right[j] + if left < right: + e = 
edges[j].replace(left=left, right=right) + t.edges.append(e) + t.build_index() + return t.tree_sequence() + + + def verify_extend_edges(self, ts, ets): + assert ts.num_samples == ets.num_samples + assert ts.num_nodes == ets.num_nodes + assert ts.num_edges >= ets.num_edges + t = ts.simplify().tables + et = ets.simplify().tables + t.assert_equals(et, ignore_provenance=True) + old_edges = {} + for e in ts.edges(): + k = (e.parent, e.child) + if k not in old_edges: + old_edges[k] = [] + old_edges[k].append((e.left, e.right)) + + for e in ets.edges(): + # e should be in old_edges, + # but with expanded limits + k = (e.parent, e.child) + assert k in old_edges + overlaps = False + for (l, r) in old_edges[k]: + if (l >= e.left) and (r <= e.right): + overlaps = True + assert overlaps + + chains = [] + for interval, tt, ett in ts.coiterate(ets): + print(interval) + print(tt.draw(format='ascii')) + print(ett.draw(format='ascii')) + this_chains = [] + for a in tt.nodes(): + b = tt.parent(a) + if b != tskit.NULL: + c = tt.parent(b) + if c != tskit.NULL: + this_chains.append((a, b, c)) + chains.append(this_chains) + + for k, (interval, tt, ett) in enumerate(ts.coiterate(ets)): + for j in (k-1, k+1): + if j < 0 or j >= len(chains): + next + else: + this_chains = chains[j] + print(j, this_chains) + for a, b, c in this_chains: + if a in tt.nodes() and tt.parent(a) == c and b not in tt.nodes(): + # the relationship a <- b <- c should still be in the tree, + # although maybe they aren't direct parent-offspring + print(a, b, c) + print("t:", list(tt.nodes())) + print("et:", list(ett.nodes())) + assert a in ett.nodes() + assert b in ett.nodes() + assert c in ett.nodes() + p = a + while p != tskit.NULL: + if p == b: + break + p = ett.parent(p) + assert p == b + while p != tskit.NULL: + if p == c: + break + p = ett.parent(p) + assert p == c + # TODO: compare C version to python version + + def test_runs(self): + ts = msprime.simulate(5, random_seed=126) + ets = ts.extend_edges() + + def 
test_simple_ex(self): + # this is an example by hand where you need to go forwards *and* backwards + # TODO actually test that to get everythign you have to! + # note that the test above might only test the forward pass + node_times = { + 0: 0, + 1: 0, + 2: 0, + 3: 0, + 4: 1.0, + 5: 1.0, + 6: 3.0, + 7: 2.0, + 8: 2.0, + } + # (p, c, l, r) + edge_stuff = [ + (4, 0, 0, 10), + (4, 1, 0, 5), + (4, 1, 7, 10), + (5, 2, 0, 2), + (5, 2, 5, 10), + (5, 3, 0, 2), + (5, 3, 5, 10), + (6, 4, 0, 2), + (6, 4, 5, 10), + (6, 5, 0, 2), + (6, 5, 7, 10), + (6, 3, 2, 5), + (6, 7, 2, 5), + (6, 8, 5, 7), + (7, 2, 2, 5), + (7, 4, 2, 5), + (8, 1, 5, 7), + (8, 5, 5, 7), + ] + tables = tskit.TableCollection(sequence_length=10) + nodes = tables.nodes + for n, t in node_times.items(): + flags = tskit.NODE_IS_SAMPLE if n < 4 else 0 + nodes.add_row(time=t, flags=flags) + edges = tables.edges + for p, c, l, r in edge_stuff: + edges.add_row(parent=p, child=c, left=l, right=r) + tables.sort() + ts = tables.tree_sequence() + ets = ts.extend_edges() + self.verify_extend_edges(ts,ets) + + def test_extend_edges(self): + tables = wf.wf_sim(5, 20, deep_history=False, seed=3) + tables.sort() + ts = tables.tree_sequence().simplify() + ets = ts.extend_edges() + self.verify_extend_edges(ts, ets) diff --git a/python/tskit/tables.py b/python/tskit/tables.py index cb2ff7d01f..c582654463 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -4037,6 +4037,12 @@ def drop_index(self): """ self._ll_tables.drop_index() + def extend_edges(self, max_iter=100): + """ + TODO DOCUMENT + """ + self._ll_tables.extend_edges(max_iter) + def subset( self, nodes, diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 177f357306..407f031e87 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -6902,6 +6902,27 @@ def decapitate(self, time, *, flags=None, population=None, metadata=None): tables.delete_older(time) return tables.tree_sequence() + + def extend_edges(self, max_iter=100): + ''' + 
Returns a new tree sequence whose unary nodes are extended to neighboring trees given the following condition: + While iterating over the tree sequence, in each tree, we identify connecting edges with unary nodes. + If an equivalent edge segment exists in the next tree without that unary node, + we extend the connecting edges from the previous tree into the next tree, + subsequently adding that unary node to the tree. + This in turn reduces the length of the edge just removed from the next tree, + and if its length becomes zero it is removed from the edge table. + + : param max_iters: (int) -- the number of iterations we analyze the tree sequence to edge extend. + The process will halt if there is no change in edge count over two consecutive iterations. Default = 100 + + :return: A new tree sequence with unary nodes extended across the tree sequence. + :rtype: tskit.TreeSequence + ''' + t = self.dump_tables() + t.extend_edges(max_iter=max_iter) + return t.tree_sequence() + def subset( self, nodes, From 0301bee5af5c4bdb733c8598a9a92a7e9bc95764 Mon Sep 17 00:00:00 2001 From: peter Date: Wed, 2 Aug 2023 14:55:16 -0700 Subject: [PATCH 71/84] removed reverse compare to python version and other tidying docstring flailing only have one edge table around --- c/tests/test_tables.c | 24 +++--- c/tskit/tables.c | 145 ++++++++++++++++------------------ python/tests/test_topology.py | 88 ++++++++------------- python/tskit/tables.py | 2 +- python/tskit/trees.py | 38 +++++---- 5 files changed, 134 insertions(+), 163 deletions(-) diff --git a/c/tests/test_tables.c b/c/tests/test_tables.c index c6c6ac0053..b1f1e17dec 100644 --- a/c/tests/test_tables.c +++ b/c/tests/test_tables.c @@ -864,7 +864,7 @@ test_table_collection_metadata(void) takeset_metadata = tsk_malloc(example_metadata_length * sizeof(char)); CU_ASSERT_FATAL(takeset_metadata != NULL); memcpy(takeset_metadata, &example_metadata, - (size_t)(example_metadata_length * sizeof(char))); + (size_t) (example_metadata_length * 
sizeof(char))); ret = tsk_table_collection_init(&tc1, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); @@ -982,7 +982,7 @@ test_node_table(void) CU_ASSERT_EQUAL(table.individual[j], j); CU_ASSERT_EQUAL(table.num_rows, (tsk_size_t) j + 1); CU_ASSERT_EQUAL( - table.metadata_length, (tsk_size_t)(j + 1) * test_metadata_length); + table.metadata_length, (tsk_size_t) (j + 1) * test_metadata_length); CU_ASSERT_EQUAL(table.metadata_offset[j + 1], table.metadata_length); /* check the metadata */ tsk_memcpy(metadata_copy, table.metadata + table.metadata_offset[j], @@ -1641,7 +1641,7 @@ test_edge_table_with_options(tsk_flags_t options) CU_ASSERT_EQUAL(table.metadata_offset, NULL); } else { CU_ASSERT_EQUAL( - table.metadata_length, (tsk_size_t)(j + 1) * test_metadata_length); + table.metadata_length, (tsk_size_t) (j + 1) * test_metadata_length); CU_ASSERT_EQUAL(table.metadata_offset[j + 1], table.metadata_length); /* check the metadata */ tsk_memcpy(metadata_copy, table.metadata + table.metadata_offset[j], @@ -2533,8 +2533,7 @@ test_extend_edges(void) ret = tsk_table_collection_extend_edges(&tables, 10, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_TRUE(tsk_table_collection_equals(&tables, &tables_copy, 0)); - - // Free things. 
+ tsk_table_collection_free(&tables); tsk_table_collection_free(&tables_copy); } @@ -2544,9 +2543,6 @@ test_edge_table_squash(void) { int ret; tsk_table_collection_t tables; - ret = tsk_table_collection_init(&tables, 0); - CU_ASSERT_EQUAL_FATAL(ret, 0); - tables.sequence_length = 10; const char *nodes_ex = "1 0 -1 -1\n" "1 0 -1 -1\n" @@ -4309,7 +4305,7 @@ test_migration_table(void) CU_ASSERT_EQUAL(table.time[j], j); CU_ASSERT_EQUAL(table.num_rows, (tsk_size_t) j + 1); CU_ASSERT_EQUAL( - table.metadata_length, (tsk_size_t)(j + 1) * test_metadata_length); + table.metadata_length, (tsk_size_t) (j + 1) * test_metadata_length); CU_ASSERT_EQUAL(table.metadata_offset[j + 1], table.metadata_length); /* check the metadata */ tsk_memcpy(metadata_copy, table.metadata + table.metadata_offset[j], @@ -5030,7 +5026,7 @@ test_individual_table(void) table.location[spatial_dimension * (size_t) j + k], test_location[k]); } CU_ASSERT_EQUAL( - table.metadata_length, (tsk_size_t)(j + 1) * test_metadata_length); + table.metadata_length, (tsk_size_t) (j + 1) * test_metadata_length); CU_ASSERT_EQUAL(table.metadata_offset[j + 1], table.metadata_length); /* check the metadata */ tsk_memcpy(metadata_copy, table.metadata + table.metadata_offset[j], @@ -5084,7 +5080,7 @@ test_individual_table(void) flags = tsk_malloc(num_rows * sizeof(tsk_flags_t)); CU_ASSERT_FATAL(flags != NULL); for (k = 0; k < num_rows; k++) { - flags[k] = (tsk_flags_t)(k + num_rows); + flags[k] = (tsk_flags_t) (k + num_rows); } location = tsk_malloc(spatial_dimension * num_rows * sizeof(double)); CU_ASSERT_FATAL(location != NULL); @@ -5099,7 +5095,7 @@ test_individual_table(void) parents = tsk_malloc(num_parents * num_rows * sizeof(tsk_id_t)); CU_ASSERT_FATAL(parents != NULL); for (k = 0; k < num_parents * num_rows; k++) { - parents[k] = (tsk_id_t)(k + (num_rows * 4)); + parents[k] = (tsk_id_t) (k + (num_rows * 4)); } parents_offset = tsk_malloc((num_rows + 1) * sizeof(tsk_size_t)); CU_ASSERT_FATAL(parents_offset != NULL); 
@@ -11459,9 +11455,9 @@ test_table_collection_takeset_indexes(void) rem = tsk_malloc(t1.edges.num_rows * sizeof(*rem)); CU_ASSERT_FATAL(rem != NULL); memcpy(ins, t1.indexes.edge_insertion_order, - (size_t)(t1.edges.num_rows * sizeof(*ins))); + (size_t) (t1.edges.num_rows * sizeof(*ins))); memcpy( - rem, t1.indexes.edge_removal_order, (size_t)(t1.edges.num_rows * sizeof(*rem))); + rem, t1.indexes.edge_removal_order, (size_t) (t1.edges.num_rows * sizeof(*rem))); ret = tsk_table_collection_copy(&t1, &t2, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); diff --git a/c/tskit/tables.c b/c/tskit/tables.c index 97ab0137ac..131414461c 100644 --- a/c/tskit/tables.c +++ b/c/tskit/tables.c @@ -12887,14 +12887,15 @@ tsk_table_collection_add_and_remap_node(tsk_table_collection_t *self, typedef struct _edge_list_t { tsk_id_t edge; - bool extended; // have we decided to extend this one on the current tree? + // the `extended` flags records whether we have decided to extend + // this entry to the current tree? + bool extended; struct _edge_list_t *next; } edge_list_t; static edge_list_t *TSK_WARN_UNUSED extend_edges_alloc_entry(tsk_blkalloc_t *heap, tsk_id_t edge) { - // see ancestor_mapper_alloc_interval_list edge_list_t *x = NULL; x = tsk_blkalloc_get(heap, sizeof(*x)); @@ -12925,7 +12926,6 @@ remove_unextended(edge_list_t **head, edge_list_t **tail) x = px->next; while (x != NULL) { if (x->extended) { - // keep it x->extended = false; px->next = x; px = x; @@ -12936,26 +12936,16 @@ remove_unextended(edge_list_t **head, edge_list_t **tail) *tail = px; } -static void -reverse_array(int *arr, int *dest, tsk_size_t n) -{ - for (tsk_size_t i = 0; i < n; i++) { - dest[i] = arr[n - i - 1]; - } -} - static int forward_extend(tsk_table_collection_t *self, int direction) { int ret = 0; - double *new_left, *new_right; tsk_id_t *num_edges; tsk_id_t tj, tk, ret_id; tsk_id_t e1, e2, e_in; tsk_id_t *I, *O; const tsk_id_t M = (tsk_id_t) self->edges.num_rows; - double left, right; - double *near_edge, 
*far_edge; + double *near_side, *far_side; tsk_blkalloc_t edge_list_heap; edge_list_t *edges_in_head, *edges_in_tail; edge_list_t *edges_out_head, *edges_out_tail; @@ -12963,20 +12953,24 @@ forward_extend(tsk_table_collection_t *self, int direction) tsk_edge_table_t edges; tsk_edge_t edge; double sign, here, there; - bool forwards; - - forwards = (direction == TSK_DIR_FORWARD); + tsk_id_t sign_int; + bool forwards = (direction == TSK_DIR_FORWARD); num_edges = tsk_malloc(self->nodes.num_rows * sizeof(*num_edges)); - new_left = tsk_malloc(self->edges.num_rows * sizeof(*new_left)); - new_right = tsk_malloc(self->edges.num_rows * sizeof(*new_right)); - if (num_edges == NULL || new_left == NULL || new_right == NULL) { + if (num_edges == NULL) { ret = TSK_ERR_NO_MEMORY; goto out; } tsk_memset(num_edges, 0x00, self->nodes.num_rows * sizeof(*num_edges)); - memcpy(new_left, self->edges.left, self->edges.num_rows * sizeof(*new_left)); - memcpy(new_right, self->edges.right, self->edges.num_rows * sizeof(*new_right)); + + ret = tsk_edge_table_copy(&self->edges, &edges, 0); + if (ret != 0) { + goto out; + } + ret = tsk_edge_table_clear(&self->edges); + if (ret != 0) { + goto out; + } ret = tsk_blkalloc_init(&edge_list_heap, 8192); if (ret != 0) { @@ -12987,33 +12981,38 @@ forward_extend(tsk_table_collection_t *self, int direction) O = self->indexes.edge_removal_order; here = 0; sign = 1; - near_edge = self->edges.left; - far_edge = self->edges.right; + sign_int = 1; + near_side = edges.left; + far_side = edges.right; + tj = 0; + tk = 0; } else { - I = tsk_malloc(self->edges.num_rows * sizeof(*I)); - O = tsk_malloc(self->edges.num_rows * sizeof(*O)); - reverse_array(self->indexes.edge_removal_order, I, self->edges.num_rows); - reverse_array(self->indexes.edge_insertion_order, O, self->edges.num_rows); + O = self->indexes.edge_insertion_order; + I = self->indexes.edge_removal_order; here = self->sequence_length; sign = -1; - near_edge = self->edges.right; - far_edge = 
self->edges.left; + sign_int = -1; + near_side = edges.right; + far_side = edges.left; + tj = M - 1; + tk = M - 1; } - tj = 0; // current position in I - tk = 0; // current position in O - left = 0; edges_in_head = NULL; edges_in_tail = NULL; edges_out_head = NULL; edges_out_tail = NULL; - while (tj < M) { + while ((tj < M) && (tj >= 0)) { // remove entries that aren't being extended/postponed remove_unextended(&edges_in_head, &edges_in_tail); remove_unextended(&edges_out_head, &edges_out_tail); - while ((tk < M) && (far_edge[O[tk]] == here)) { + while (((tk < M) && (tk >= 0)) && (far_side[O[tk]] == here)) { // add edge tk to pending_out x = extend_edges_alloc_entry(&edge_list_heap, O[tk]); + if (x == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } if (edges_out_tail == NULL) { edges_out_head = x; } else { @@ -13021,13 +13020,17 @@ forward_extend(tsk_table_collection_t *self, int direction) y->next = x; } edges_out_tail = x; - num_edges[self->edges.parent[O[tk]]] -= 1; - num_edges[self->edges.child[O[tk]]] -= 1; - tk++; + num_edges[edges.parent[O[tk]]] -= 1; + num_edges[edges.child[O[tk]]] -= 1; + tk += sign_int; } - while ((tj < M) && (near_edge[I[tj]] == here)) { + while (((tj < M) && (tj >= 0)) && (near_side[I[tj]] == here)) { // add edge tj to pending_in x = extend_edges_alloc_entry(&edge_list_heap, I[tj]); + if (x == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } if (edges_in_tail == NULL) { edges_in_head = x; } else { @@ -13035,24 +13038,24 @@ forward_extend(tsk_table_collection_t *self, int direction) y->next = x; } edges_in_tail = x; - num_edges[self->edges.parent[I[tj]]] += 1; - num_edges[self->edges.child[I[tj]]] += 1; - tj++; + num_edges[edges.parent[I[tj]]] += 1; + num_edges[edges.child[I[tj]]] += 1; + tj += sign_int; } there = forwards ? 
self->sequence_length : 0; if (forwards) { if (tk < M) { - there = TSK_MIN(there, far_edge[O[tk]]); + there = TSK_MIN(there, far_side[O[tk]]); } if (tj < M) { - there = TSK_MIN(there, near_edge[I[tj]]); + there = TSK_MIN(there, near_side[I[tj]]); } } else { - if (tk < M) { - there = TSK_MAX(there, far_edge[O[tk]]); + if (tk >= 0) { + there = TSK_MAX(there, far_side[O[tk]]); } - if (tj < M) { - there = TSK_MAX(there, near_edge[I[tj]]); + if (tj >= 0) { + there = TSK_MAX(there, near_side[I[tj]]); } } @@ -13063,29 +13066,27 @@ forward_extend(tsk_table_collection_t *self, int direction) for (ex2 = edges_out_head; ex2 != NULL; ex2 = ex2->next) { if (!ex2->extended) { e2 = ex2->edge; - if ((self->edges.parent[e1] == self->edges.child[e2]) - && (num_edges[self->edges.child[e2]] == 0)) { + if ((edges.parent[e1] == edges.child[e2]) + && (num_edges[edges.child[e2]] == 0)) { for (ex_in = edges_in_head; ex_in != NULL; ex_in = ex_in->next) { e_in = ex_in->edge; - if (sign * far_edge[e_in] > sign * here) { - if ((self->edges.child[e1] - == self->edges.child[e_in]) - && (self->edges.parent[e2] - == self->edges.parent[e_in])) { + if (sign * far_side[e_in] > sign * here) { + if ((edges.child[e1] == edges.child[e_in]) + && (edges.parent[e2] == edges.parent[e_in])) { ex1->extended = true; ex2->extended = true; ex_in->extended = true; if (forwards) { - new_right[e1] = there; - new_right[e2] = there; - new_left[e_in] = there; + edges.right[e1] = there; + edges.right[e2] = there; + edges.left[e_in] = there; } else { - new_left[e1] = there; - new_left[e2] = there; - new_right[e_in] = there; + edges.left[e1] = there; + edges.left[e2] = there; + edges.right[e_in] = there; } - num_edges[self->edges.parent[e1]] += 2; + num_edges[edges.parent[e1]] += 2; } } } @@ -13099,21 +13100,11 @@ forward_extend(tsk_table_collection_t *self, int direction) } // done! 
write out new edge tables - ret = tsk_edge_table_copy(&self->edges, &edges, 0); - if (ret != 0) { - goto out; - } - ret = tsk_edge_table_clear(&self->edges); - if (ret != 0) { - goto out; - } for (tj = 0; tj < (tsk_id_t) edges.num_rows; tj++) { - left = new_left[tj]; - right = new_right[tj]; - if (left < right) { - tsk_edge_table_get_row_unsafe(&edges, tj, &edge); - ret_id = tsk_edge_table_add_row(&self->edges, left, right, edge.parent, - edge.child, edge.metadata, edge.metadata_length); + tsk_edge_table_get_row_unsafe(&edges, tj, &edge); + if (edge.left < edge.right) { + ret_id = tsk_edge_table_add_row(&self->edges, edge.left, edge.right, + edge.parent, edge.child, edge.metadata, edge.metadata_length); if (ret_id < 0) { ret = (int) ret_id; goto out; @@ -13127,6 +13118,8 @@ forward_extend(tsk_table_collection_t *self, int direction) out: tsk_blkalloc_free(&edge_list_heap); + tsk_safe_free(num_edges); + tsk_edge_table_free(&edges); return ret; } diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index 3808d12390..821a77be41 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -8522,23 +8522,21 @@ class TestExtendEdges: Test the 'extend edges' method """ - def py_extend_edges(): - ts = self + def py_extend_edges(self, ts, max_iter=10): last_num_edges = ts.num_edges for _ in range(max_iter): - ts = ts.forward_extend(forwards=True) - ts = ts.forward_extend(forwards=False) + ts = self._extend(ts, forwards=True) + ts = self._extend(ts, forwards=False) if ts.num_edges == last_num_edges: break else: last_num_edges = ts.num_edges return ts - def _extend(self, forwards=True): - print("forwards:", forwards) - num_edges = np.full(self.num_nodes, 0) + def _extend(self, ts, forwards=True): + num_edges = np.full(ts.num_nodes, 0) - t = self.tables + t = ts.tables edges = t.edges.copy() t.edges.clear() new_left = edges.left @@ -8547,19 +8545,17 @@ def _extend(self, forwards=True): # edge diff stuff M = edges.num_rows if forwards: 
- I = self.indexes_edge_insertion_order - O = self.indexes_edge_removal_order + I = ts.indexes_edge_insertion_order # NOQA E741 + O = ts.indexes_edge_removal_order # NOQA E741 # "here" will be left if fowards else right here = 0 - endpoint = self.sequence_length sign = +1 near_edge = edges.left far_edge = edges.right else: - I = np.flip(self.indexes_edge_removal_order) - O = np.flip(self.indexes_edge_insertion_order) - here = self.sequence_length - endpoint = 0 + I = np.flip(ts.indexes_edge_removal_order) # NOQA E741 + O = np.flip(ts.indexes_edge_insertion_order) # NOQA E741 + here = ts.sequence_length sign = -1 near_edge = edges.right far_edge = edges.left @@ -8568,65 +8564,53 @@ def _extend(self, forwards=True): edges_out = [] edges_in = [] - while (tj < M): + while tj < M: # clear out non-extended or postponed edges edges_out = [[e, False] for e, x in edges_out if x] edges_in = [[e, False] for e, x in edges_in if x] - - # Find edges_out between trees + while (tk < M) and (far_edge[O[tk]] == here): edges_out.append([O[tk], False]) num_edges[edges.parent[O[tk]]] -= 1 num_edges[edges.child[O[tk]]] -= 1 - #print("Edge Out", tk, edges[O[tk]]) tk += 1 - # Find edges_in between trees + while (tj < M) and (near_edge[I[tj]] == here): edges_in.append([I[tj], False]) num_edges[edges.parent[I[tj]]] += 1 num_edges[edges.child[I[tj]]] += 1 - #print("Edge In", tj, edges[I[tj]]) tj += 1 - - # Find smallest length right endpoint of all edges in edges_in and edges_out - # there should equal the endpoint of a T_k - there = self.sequence_length if forwards else 0 + + there = ts.sequence_length if forwards else 0 if forwards: if tk < M: there = min(there, far_edge[O[tk]]) if tj < M: there = min(there, near_edge[I[tj]]) - else: + else: if tk < M: there = max(there, far_edge[O[tk]]) if tj < M: there = max(there, near_edge[I[tj]]) - print("All Edges Out", edges_out) - print("All Edges In", edges_in) assert np.all(num_edges >= 0) - print("-------------", here, len(edges_out), 
len(edges_in)) for ex1 in edges_out: - #print("e1:", e1, [edges.parent[O[e1]], edges.child[O[e1]]], edges[O[e1]]) if not ex1[1]: e1 = ex1[0] for ex2 in edges_out: - #print("e2:", e2, num_edges[edges.child[e2]], ":", [edges.parent[O[e2]], edges.child[O[e2]]]) if not ex2[1]: - # need the intermediate node to not be present in + # the intermediate node should not be present in # the new tree e2 = ex2[0] - if ((edges.parent[e1] == edges.child[e2]) - and (num_edges[edges.child[e2]] == 0)): + if (edges.parent[e1] == edges.child[e2]) and ( + num_edges[edges.child[e2]] == 0 + ): for ex_in in edges_in: e_in = ex_in[0] - #print("ein", e_in, [edges.parent[I[e_in]], edges.child[I[e_in]]]) if sign * far_edge[e_in] > sign * here: if ( edges.child[e1] == edges.child[e_in] and edges.parent[e2] == edges.parent[e_in] ): - print("EXTEND") - # extend e2->e1 and postpone e_in ex1[1] = True ex2[1] = True ex_in[1] = True @@ -8653,8 +8637,8 @@ def _extend(self, forwards=True): t.build_index() return t.tree_sequence() - - def verify_extend_edges(self, ts, ets): + def verify_extend_edges(self, ts): + ets = ts.extend_edges() assert ts.num_samples == ets.num_samples assert ts.num_nodes == ets.num_nodes assert ts.num_edges >= ets.num_edges @@ -8680,10 +8664,7 @@ def verify_extend_edges(self, ts, ets): assert overlaps chains = [] - for interval, tt, ett in ts.coiterate(ets): - print(interval) - print(tt.draw(format='ascii')) - print(ett.draw(format='ascii')) + for _, tt, _ett in ts.coiterate(ets): this_chains = [] for a in tt.nodes(): b = tt.parent(a) @@ -8693,20 +8674,16 @@ def verify_extend_edges(self, ts, ets): this_chains.append((a, b, c)) chains.append(this_chains) - for k, (interval, tt, ett) in enumerate(ts.coiterate(ets)): - for j in (k-1, k+1): + for k, (_, tt, ett) in enumerate(ts.coiterate(ets)): + for j in (k - 1, k + 1): if j < 0 or j >= len(chains): next else: this_chains = chains[j] - print(j, this_chains) for a, b, c in this_chains: if a in tt.nodes() and tt.parent(a) == c and b 
not in tt.nodes(): # the relationship a <- b <- c should still be in the tree, # although maybe they aren't direct parent-offspring - print(a, b, c) - print("t:", list(tt.nodes())) - print("et:", list(ett.nodes())) assert a in ett.nodes() assert b in ett.nodes() assert c in ett.nodes() @@ -8721,11 +8698,14 @@ def verify_extend_edges(self, ts, ets): break p = ett.parent(p) assert p == c - # TODO: compare C version to python version + # finally, compare C version to python version + py_et = self.py_extend_edges(ts).dump_tables() + et = ets.dump_tables() + et.assert_equals(py_et) def test_runs(self): ts = msprime.simulate(5, random_seed=126) - ets = ts.extend_edges() + self.verify_extend_edges(ts) def test_simple_ex(self): # this is an example by hand where you need to go forwards *and* backwards @@ -8773,12 +8753,10 @@ def test_simple_ex(self): edges.add_row(parent=p, child=c, left=l, right=r) tables.sort() ts = tables.tree_sequence() - ets = ts.extend_edges() - self.verify_extend_edges(ts,ets) + self.verify_extend_edges(ts) def test_extend_edges(self): tables = wf.wf_sim(5, 20, deep_history=False, seed=3) tables.sort() ts = tables.tree_sequence().simplify() - ets = ts.extend_edges() - self.verify_extend_edges(ts, ets) + self.verify_extend_edges(ts) diff --git a/python/tskit/tables.py b/python/tskit/tables.py index c582654463..9a95e3e2b8 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -4037,7 +4037,7 @@ def drop_index(self): """ self._ll_tables.drop_index() - def extend_edges(self, max_iter=100): + def extend_edges(self, max_iter=10): """ TODO DOCUMENT """ diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 407f031e87..74c010ff9b 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -6902,23 +6902,27 @@ def decapitate(self, time, *, flags=None, population=None, metadata=None): tables.delete_older(time) return tables.tree_sequence() - - def extend_edges(self, max_iter=100): - ''' - Returns a new tree sequence whose unary 
nodes are extended to neighboring trees given the following condition: - While iterating over the tree sequence, in each tree, we identify connecting edges with unary nodes. - If an equivalent edge segment exists in the next tree without that unary node, - we extend the connecting edges from the previous tree into the next tree, - subsequently adding that unary node to the tree. - This in turn reduces the length of the edge just removed from the next tree, - and if its length becomes zero it is removed from the edge table. - - : param max_iters: (int) -- the number of iterations we analyze the tree sequence to edge extend. - The process will halt if there is no change in edge count over two consecutive iterations. Default = 100 - - :return: A new tree sequence with unary nodes extended across the tree sequence. - :rtype: tskit.TreeSequence - ''' + def extend_edges(self, max_iter=10): + """ + TODO: make this better + Returns a new tree sequence in which the span covered by ancestral nodes + is "extended" to regions of the genome over which their ancestry is + unambiguous, which occurs if the node is an intermediate in a chain + of ancestry that also exists in neighboring regions of the genome. + While iterating over the tree sequence, in each tree, we identify + connecting edges with unary nodes. If an equivalent edge segment + exists in the next tree without that unary node, we extend the + connecting edges from the previous tree into the next tree, + subsequently adding that unary node to the tree. This in turn reduces + the length of the edge just removed from the next tree, and if its + length becomes zero it is removed from the edge table. + + :param int max_iters: The maximum number of forward-and-backward + iterations over the tree sequence. Defaults to 10. + + :return: A new tree sequence with unary nodes extended. 
+ :rtype: tskit.TreeSequence + """ t = self.dump_tables() t.extend_edges(max_iter=max_iter) return t.tree_sequence() From b030cf107f71e0794f33294737b4607c0964da48 Mon Sep 17 00:00:00 2001 From: peter Date: Thu, 3 Aug 2023 10:51:47 -0700 Subject: [PATCH 72/84] make docs build --- c/tskit/tables.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/c/tskit/tables.h b/c/tskit/tables.h index 872b9b8fa1..4336bc1249 100644 --- a/c/tskit/tables.h +++ b/c/tskit/tables.h @@ -4370,10 +4370,12 @@ TODO DOCUMENT @endrst @param self A pointer to a tsk_table_collection_t object. +@param max_iter The maximum number of iterations over the tree sequence. @param options Bitwise option flags. (UNUSED) @return Return 0 on success or a negative value on failure. */ -int tsk_table_collection_extend_edges(tsk_table_collection_t *self, int max_iter, tsk_flags_t options); +int tsk_table_collection_extend_edges( + tsk_table_collection_t *self, int max_iter, tsk_flags_t options); /** @brief Subsets and reorders a table collection according to an array of nodes. 
From 426b9035766df4c3ae6f5c8551736a79b6fa301c Mon Sep 17 00:00:00 2001 From: peter Date: Thu, 3 Aug 2023 18:21:40 -0700 Subject: [PATCH 73/84] clang-format-6'ed --- c/tests/test_tables.c | 18 +++++++++--------- c/tskit/tables.c | 30 +++++++++++++++--------------- python/_tskitmodule.c | 1 - 3 files changed, 24 insertions(+), 25 deletions(-) diff --git a/c/tests/test_tables.c b/c/tests/test_tables.c index b1f1e17dec..5c96cced0e 100644 --- a/c/tests/test_tables.c +++ b/c/tests/test_tables.c @@ -864,7 +864,7 @@ test_table_collection_metadata(void) takeset_metadata = tsk_malloc(example_metadata_length * sizeof(char)); CU_ASSERT_FATAL(takeset_metadata != NULL); memcpy(takeset_metadata, &example_metadata, - (size_t) (example_metadata_length * sizeof(char))); + (size_t)(example_metadata_length * sizeof(char))); ret = tsk_table_collection_init(&tc1, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); @@ -982,7 +982,7 @@ test_node_table(void) CU_ASSERT_EQUAL(table.individual[j], j); CU_ASSERT_EQUAL(table.num_rows, (tsk_size_t) j + 1); CU_ASSERT_EQUAL( - table.metadata_length, (tsk_size_t) (j + 1) * test_metadata_length); + table.metadata_length, (tsk_size_t)(j + 1) * test_metadata_length); CU_ASSERT_EQUAL(table.metadata_offset[j + 1], table.metadata_length); /* check the metadata */ tsk_memcpy(metadata_copy, table.metadata + table.metadata_offset[j], @@ -1641,7 +1641,7 @@ test_edge_table_with_options(tsk_flags_t options) CU_ASSERT_EQUAL(table.metadata_offset, NULL); } else { CU_ASSERT_EQUAL( - table.metadata_length, (tsk_size_t) (j + 1) * test_metadata_length); + table.metadata_length, (tsk_size_t)(j + 1) * test_metadata_length); CU_ASSERT_EQUAL(table.metadata_offset[j + 1], table.metadata_length); /* check the metadata */ tsk_memcpy(metadata_copy, table.metadata + table.metadata_offset[j], @@ -4305,7 +4305,7 @@ test_migration_table(void) CU_ASSERT_EQUAL(table.time[j], j); CU_ASSERT_EQUAL(table.num_rows, (tsk_size_t) j + 1); CU_ASSERT_EQUAL( - table.metadata_length, (tsk_size_t) (j + 1) * 
test_metadata_length); + table.metadata_length, (tsk_size_t)(j + 1) * test_metadata_length); CU_ASSERT_EQUAL(table.metadata_offset[j + 1], table.metadata_length); /* check the metadata */ tsk_memcpy(metadata_copy, table.metadata + table.metadata_offset[j], @@ -5026,7 +5026,7 @@ test_individual_table(void) table.location[spatial_dimension * (size_t) j + k], test_location[k]); } CU_ASSERT_EQUAL( - table.metadata_length, (tsk_size_t) (j + 1) * test_metadata_length); + table.metadata_length, (tsk_size_t)(j + 1) * test_metadata_length); CU_ASSERT_EQUAL(table.metadata_offset[j + 1], table.metadata_length); /* check the metadata */ tsk_memcpy(metadata_copy, table.metadata + table.metadata_offset[j], @@ -5080,7 +5080,7 @@ test_individual_table(void) flags = tsk_malloc(num_rows * sizeof(tsk_flags_t)); CU_ASSERT_FATAL(flags != NULL); for (k = 0; k < num_rows; k++) { - flags[k] = (tsk_flags_t) (k + num_rows); + flags[k] = (tsk_flags_t)(k + num_rows); } location = tsk_malloc(spatial_dimension * num_rows * sizeof(double)); CU_ASSERT_FATAL(location != NULL); @@ -5095,7 +5095,7 @@ test_individual_table(void) parents = tsk_malloc(num_parents * num_rows * sizeof(tsk_id_t)); CU_ASSERT_FATAL(parents != NULL); for (k = 0; k < num_parents * num_rows; k++) { - parents[k] = (tsk_id_t) (k + (num_rows * 4)); + parents[k] = (tsk_id_t)(k + (num_rows * 4)); } parents_offset = tsk_malloc((num_rows + 1) * sizeof(tsk_size_t)); CU_ASSERT_FATAL(parents_offset != NULL); @@ -11455,9 +11455,9 @@ test_table_collection_takeset_indexes(void) rem = tsk_malloc(t1.edges.num_rows * sizeof(*rem)); CU_ASSERT_FATAL(rem != NULL); memcpy(ins, t1.indexes.edge_insertion_order, - (size_t) (t1.edges.num_rows * sizeof(*ins))); + (size_t)(t1.edges.num_rows * sizeof(*ins))); memcpy( - rem, t1.indexes.edge_removal_order, (size_t) (t1.edges.num_rows * sizeof(*rem))); + rem, t1.indexes.edge_removal_order, (size_t)(t1.edges.num_rows * sizeof(*rem))); ret = tsk_table_collection_copy(&t1, &t2, 0); CU_ASSERT_EQUAL_FATAL(ret, 
0); diff --git a/c/tskit/tables.c b/c/tskit/tables.c index 131414461c..a10c2c33bb 100644 --- a/c/tskit/tables.c +++ b/c/tskit/tables.c @@ -8321,8 +8321,8 @@ pair_to_integer(tsk_id_t a, tsk_id_t b, tsk_size_t N) static inline void integer_to_pair(int64_t index, tsk_size_t N, tsk_id_t *a, tsk_id_t *b) { - *a = (tsk_id_t) (index / (int64_t) N); - *b = (tsk_id_t) (index % (int64_t) N); + *a = (tsk_id_t)(index / (int64_t) N); + *b = (tsk_id_t)(index % (int64_t) N); } static int64_t @@ -10806,7 +10806,7 @@ tsk_table_collection_check_individual_integrity( /* Check parent references are valid */ if (individuals.parents[k] != TSK_NULL && (individuals.parents[k] < 0 - || individuals.parents[k] >= num_individuals)) { + || individuals.parents[k] >= num_individuals)) { ret = TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS; goto out; } @@ -11160,20 +11160,20 @@ tsk_table_collection_equals(const tsk_table_collection_t *self, if (!(options & TSK_CMP_IGNORE_TABLES)) { ret = ret && tsk_individual_table_equals( - &self->individuals, &other->individuals, options) + &self->individuals, &other->individuals, options) && tsk_node_table_equals(&self->nodes, &other->nodes, options) && tsk_edge_table_equals(&self->edges, &other->edges, options) && tsk_migration_table_equals( - &self->migrations, &other->migrations, options) + &self->migrations, &other->migrations, options) && tsk_site_table_equals(&self->sites, &other->sites, options) && tsk_mutation_table_equals(&self->mutations, &other->mutations, options) && tsk_population_table_equals( - &self->populations, &other->populations, options); + &self->populations, &other->populations, options); /* TSK_CMP_IGNORE_TABLES implies TSK_CMP_IGNORE_PROVENANCE */ if (!(options & TSK_CMP_IGNORE_PROVENANCE)) { ret = ret && tsk_provenance_table_equals( - &self->provenances, &other->provenances, options); + &self->provenances, &other->provenances, options); } } /* TSK_CMP_IGNORE_TS_METADATA is implied by TSK_CMP_IGNORE_METADATA */ @@ -11183,19 +11183,19 @@ 
tsk_table_collection_equals(const tsk_table_collection_t *self, if (!(options & TSK_CMP_IGNORE_TS_METADATA)) { ret = ret && (self->metadata_length == other->metadata_length - && self->metadata_schema_length == other->metadata_schema_length - && tsk_memcmp(self->metadata, other->metadata, - self->metadata_length * sizeof(char)) - == 0 - && tsk_memcmp(self->metadata_schema, other->metadata_schema, - self->metadata_schema_length * sizeof(char)) - == 0); + && self->metadata_schema_length == other->metadata_schema_length + && tsk_memcmp(self->metadata, other->metadata, + self->metadata_length * sizeof(char)) + == 0 + && tsk_memcmp(self->metadata_schema, other->metadata_schema, + self->metadata_schema_length * sizeof(char)) + == 0); } if (!(options & TSK_CMP_IGNORE_REFERENCE_SEQUENCE)) { ret = ret && tsk_reference_sequence_equals( - &self->reference_sequence, &other->reference_sequence, options); + &self->reference_sequence, &other->reference_sequence, options); } return ret; } diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 5a865535ba..55dd1c3253 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -7015,7 +7015,6 @@ TableCollection_extend_edges(TableCollection *self, PyObject *args, PyObject *kw return ret; } - static PyObject * TableCollection_subset(TableCollection *self, PyObject *args, PyObject *kwds) { From 8a7c84ff08e51676e66dcd446a0056e2a3341679 Mon Sep 17 00:00:00 2001 From: peter Date: Fri, 4 Aug 2023 07:23:53 -0700 Subject: [PATCH 74/84] C tidyup --- c/tests/test_tables.c | 75 ++++++++++++++++++++++++++++++++++- c/tskit/tables.c | 67 +++++++++++++++++++------------ python/tests/test_topology.py | 43 ++++++++++++++------ 3 files changed, 145 insertions(+), 40 deletions(-) diff --git a/c/tests/test_tables.c b/c/tests/test_tables.c index 5c96cced0e..58386894ab 100644 --- a/c/tests/test_tables.c +++ b/c/tests/test_tables.c @@ -2508,7 +2508,7 @@ test_edge_table_copy_semantics(void) } static void -test_extend_edges(void) 
+test_extend_edges_simple(void) { int ret; tsk_table_collection_t tables, tables_copy; @@ -2527,9 +2527,13 @@ test_extend_edges(void) CU_ASSERT_EQUAL_FATAL(tables.nodes.num_rows, 3); parse_edges(edges_ex, &tables.edges); CU_ASSERT_EQUAL_FATAL(tables.edges.num_rows, 2); + + tsk_table_collection_drop_index(&tables, 0); + ret = tsk_table_collection_extend_edges(&tables, 10, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_TABLES_NOT_INDEXED); + tsk_table_collection_build_index(&tables, 0); tsk_table_collection_copy(&tables, &tables_copy, 0); - ret = tsk_table_collection_extend_edges(&tables, 10, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_TRUE(tsk_table_collection_equals(&tables, &tables_copy, 0)); @@ -2538,6 +2542,72 @@ test_extend_edges(void) tsk_table_collection_free(&tables_copy); } +static void +test_extend_edges(void) +{ + int ret; + tsk_table_collection_t tables; + /* 7 and 8 should be extended to the whole sequence + + 6 6 6 6 + +-+-+ +-+-+ +-+-+ +-+-+ + | | 7 | | 8 | | + | | ++-+ | | +-++ | | + 4 5 4 | | 4 | 5 4 5 + +++ +++ +++ | | | | +++ +++ +++ + 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 + */ + + const char *nodes_ex = "1 0 -1 -1\n" + "1 0 -1 -1\n" + "1 0 -1 -1\n" + "1 0 -1 -1\n" + "0 1.0 -1 -1\n" + "0 1.0 -1 -1\n" + "0 3.0 -1 -1\n" + "0 2.0 -1 -1\n" + "0 2.0 -1 -1\n"; + // l, r, p, c + const char *edges_ex = "0 10 4 0\n" + "0 5 4 1\n" + "7 10 4 1\n" + "0 2 5 2\n" + "5 10 5 2\n" + "0 2 5 3\n" + "5 10 5 3\n" + "0 2 6 4\n" + "5 10 6 4\n" + "0 2 6 5\n" + "7 10 6 5\n" + "2 5 6 3\n" + "2 5 6 7\n" + "5 7 6 8\n" + "2 5 7 2\n" + "2 5 7 4\n" + "5 7 8 1\n" + "5 7 8 5\n"; + + ret = tsk_table_collection_init(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tables.sequence_length = 10; + + parse_nodes(nodes_ex, &tables.nodes); + CU_ASSERT_EQUAL_FATAL(tables.nodes.num_rows, 9); + parse_edges(edges_ex, &tables.edges); + CU_ASSERT_EQUAL_FATAL(tables.edges.num_rows, 18); + ret = tsk_table_collection_sort(&tables, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = 
tsk_table_collection_build_index(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_table_collection_extend_edges(&tables, 10, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(tables.nodes.num_rows, 9); + CU_ASSERT_EQUAL_FATAL(tables.edges.num_rows, 13); + + tsk_table_collection_free(&tables); +} + static void test_edge_table_squash(void) { @@ -11637,6 +11707,7 @@ main(int argc, char **argv) { "test_simplify_tables_drops_indexes", test_simplify_tables_drops_indexes }, { "test_simplify_empty_tables", test_simplify_empty_tables }, { "test_simplify_metadata", test_simplify_metadata }, + { "test_extend_edges_simple", test_extend_edges_simple }, { "test_extend_edges", test_extend_edges }, { "test_link_ancestors_no_edges", test_link_ancestors_no_edges }, { "test_link_ancestors_input_errors", test_link_ancestors_input_errors }, diff --git a/c/tskit/tables.c b/c/tskit/tables.c index a10c2c33bb..3dff91728e 100644 --- a/c/tskit/tables.c +++ b/c/tskit/tables.c @@ -12939,8 +12939,12 @@ remove_unextended(edge_list_t **head, edge_list_t **tail) static int forward_extend(tsk_table_collection_t *self, int direction) { + // Note: this modifies the edge table, but it does this by (a) removing + // some edges, and (b) extending left/right endpoints of others, + // while keeping order the same, and so this maintains sortedness + // (so, there is no need to sort afterwards). 
int ret = 0; - tsk_id_t *num_edges; + tsk_id_t *num_children; tsk_id_t tj, tk, ret_id; tsk_id_t e1, e2, e_in; tsk_id_t *I, *O; @@ -12956,26 +12960,35 @@ forward_extend(tsk_table_collection_t *self, int direction) tsk_id_t sign_int; bool forwards = (direction == TSK_DIR_FORWARD); - num_edges = tsk_malloc(self->nodes.num_rows * sizeof(*num_edges)); - if (num_edges == NULL) { + // need to do this so tsk_safe_free works on it if it's not initialized + memset(&edges, 0, sizeof(edges)); + + num_children = tsk_malloc(self->nodes.num_rows * sizeof(*num_children)); + if (num_children == NULL) { ret = TSK_ERR_NO_MEMORY; goto out; } - tsk_memset(num_edges, 0x00, self->nodes.num_rows * sizeof(*num_edges)); + tsk_memset(num_children, 0x00, self->nodes.num_rows * sizeof(*num_children)); - ret = tsk_edge_table_copy(&self->edges, &edges, 0); + ret = tsk_blkalloc_init(&edge_list_heap, 8192); if (ret != 0) { goto out; } - ret = tsk_edge_table_clear(&self->edges); - if (ret != 0) { + + if (!tsk_table_collection_has_index(self, 0)) { + ret = TSK_ERR_TABLES_NOT_INDEXED; goto out; } - ret = tsk_blkalloc_init(&edge_list_heap, 8192); + ret = tsk_edge_table_copy(&self->edges, &edges, 0); + if (ret != 0) { + goto out; + } + ret = tsk_edge_table_clear(&self->edges); if (ret != 0) { goto out; } + if (forwards) { I = self->indexes.edge_insertion_order; O = self->indexes.edge_removal_order; @@ -13020,8 +13033,8 @@ forward_extend(tsk_table_collection_t *self, int direction) y->next = x; } edges_out_tail = x; - num_edges[edges.parent[O[tk]]] -= 1; - num_edges[edges.child[O[tk]]] -= 1; + num_children[edges.parent[O[tk]]] -= 1; + num_children[edges.child[O[tk]]] -= 1; tk += sign_int; } while (((tj < M) && (tj >= 0)) && (near_side[I[tj]] == here)) { @@ -13038,8 +13051,8 @@ forward_extend(tsk_table_collection_t *self, int direction) y->next = x; } edges_in_tail = x; - num_edges[edges.parent[I[tj]]] += 1; - num_edges[edges.child[I[tj]]] += 1; + num_children[edges.parent[I[tj]]] += 1; + 
num_children[edges.child[I[tj]]] += 1; tj += sign_int; } there = forwards ? self->sequence_length : 0; @@ -13067,7 +13080,7 @@ forward_extend(tsk_table_collection_t *self, int direction) if (!ex2->extended) { e2 = ex2->edge; if ((edges.parent[e1] == edges.child[e2]) - && (num_edges[edges.child[e2]] == 0)) { + && (num_children[edges.child[e2]] == 0)) { for (ex_in = edges_in_head; ex_in != NULL; ex_in = ex_in->next) { e_in = ex_in->edge; @@ -13077,16 +13090,10 @@ forward_extend(tsk_table_collection_t *self, int direction) ex1->extended = true; ex2->extended = true; ex_in->extended = true; - if (forwards) { - edges.right[e1] = there; - edges.right[e2] = there; - edges.left[e_in] = there; - } else { - edges.left[e1] = there; - edges.left[e2] = there; - edges.right[e_in] = there; - } - num_edges[edges.parent[e1]] += 2; + far_side[e1] = there; + far_side[e2] = there; + near_side[e_in] = there; + num_children[edges.parent[e1]] += 2; } } } @@ -13118,7 +13125,7 @@ forward_extend(tsk_table_collection_t *self, int direction) out: tsk_blkalloc_free(&edge_list_heap); - tsk_safe_free(num_edges); + tsk_safe_free(num_children); tsk_edge_table_free(&edges); return ret; } @@ -13131,14 +13138,22 @@ tsk_table_collection_extend_edges( tsk_size_t last_num_edges; last_num_edges = self->edges.num_rows; for (int j = 0; j <= max_iter; j++) { - forward_extend(self, TSK_DIR_FORWARD); - forward_extend(self, TSK_DIR_REVERSE); + ret = forward_extend(self, TSK_DIR_FORWARD); + if (ret != 0) { + goto out; + } + ret = forward_extend(self, TSK_DIR_REVERSE); + if (ret != 0) { + goto out; + } if (self->edges.num_rows == last_num_edges) { break; } else { last_num_edges = self->edges.num_rows; } } + +out: return ret; } diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index 821a77be41..43b382c4bb 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -8534,7 +8534,7 @@ def py_extend_edges(self, ts, max_iter=10): return ts def _extend(self, ts, 
forwards=True): - num_edges = np.full(ts.num_nodes, 0) + num_children = np.full(ts.num_nodes, 0) t = ts.tables edges = t.edges.copy() @@ -8571,14 +8571,14 @@ def _extend(self, ts, forwards=True): while (tk < M) and (far_edge[O[tk]] == here): edges_out.append([O[tk], False]) - num_edges[edges.parent[O[tk]]] -= 1 - num_edges[edges.child[O[tk]]] -= 1 + num_children[edges.parent[O[tk]]] -= 1 + num_children[edges.child[O[tk]]] -= 1 tk += 1 while (tj < M) and (near_edge[I[tj]] == here): edges_in.append([I[tj], False]) - num_edges[edges.parent[I[tj]]] += 1 - num_edges[edges.child[I[tj]]] += 1 + num_children[edges.parent[I[tj]]] += 1 + num_children[edges.child[I[tj]]] += 1 tj += 1 there = ts.sequence_length if forwards else 0 @@ -8592,7 +8592,7 @@ def _extend(self, ts, forwards=True): there = max(there, far_edge[O[tk]]) if tj < M: there = max(there, near_edge[I[tj]]) - assert np.all(num_edges >= 0) + assert np.all(num_children >= 0) for ex1 in edges_out: if not ex1[1]: e1 = ex1[0] @@ -8602,7 +8602,7 @@ def _extend(self, ts, forwards=True): # the new tree e2 = ex2[0] if (edges.parent[e1] == edges.child[e2]) and ( - num_edges[edges.child[e2]] == 0 + num_children[edges.child[e2]] == 0 ): for ex_in in edges_in: e_in = ex_in[0] @@ -8622,9 +8622,9 @@ def _extend(self, ts, forwards=True): new_left[e1] = there new_left[e2] = there new_right[e_in] = there - # amend num_edges: the intermediate + # amend num_children: the intermediate # node has 2 edges instead of 0 - num_edges[edges.parent[e1]] += 2 + num_children[edges.parent[e1]] += 2 # cleanup at end of loop here = there @@ -8708,9 +8708,17 @@ def test_runs(self): self.verify_extend_edges(ts) def test_simple_ex(self): - # this is an example by hand where you need to go forwards *and* backwards - # TODO actually test that to get everythign you have to! 
- # note that the test above might only test the forward pass + # An example where you need to go forwards *and* backwards: + # 7 and 8 should be extended to the whole sequence + # + # 6 6 6 6 + # +-+-+ +-+-+ +-+-+ +-+-+ + # | | 7 | | 8 | | + # | | ++-+ | | +-++ | | + # 4 5 4 | | 4 | 5 4 5 + # +++ +++ +++ | | | | +++ +++ +++ + # 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 + # node_times = { 0: 0, 1: 0, @@ -8753,6 +8761,17 @@ def test_simple_ex(self): edges.add_row(parent=p, child=c, left=l, right=r) tables.sort() ts = tables.tree_sequence() + tables.extend_edges() + ets = tables.tree_sequence() + assert ts.num_edges == 18 + assert ets.num_edges == 13 + for t in ets.trees(): + assert 7 in t.nodes() + assert 8 in t.nodes() + assert t.parent(4) == 7 + assert t.parent(7) == 6 + assert t.parent(5) == 8 + assert t.parent(8) == 6 self.verify_extend_edges(ts) def test_extend_edges(self): From eb1fe3281a599dd8b1473595d780842b307c0d80 Mon Sep 17 00:00:00 2001 From: peter Date: Fri, 4 Aug 2023 08:05:48 -0700 Subject: [PATCH 75/84] factor out append --- c/tskit/tables.c | 43 ++++++++++++++++++++----------------------- 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/c/tskit/tables.c b/c/tskit/tables.c index 3dff91728e..cc0345a977 100644 --- a/c/tskit/tables.c +++ b/c/tskit/tables.c @@ -12893,22 +12893,31 @@ typedef struct _edge_list_t { struct _edge_list_t *next; } edge_list_t; -static edge_list_t *TSK_WARN_UNUSED -extend_edges_alloc_entry(tsk_blkalloc_t *heap, tsk_id_t edge) +static int +extend_edges_append_entry( + edge_list_t **head, edge_list_t **tail, tsk_blkalloc_t *heap, tsk_id_t edge) { + int ret = 0; edge_list_t *x = NULL; x = tsk_blkalloc_get(heap, sizeof(*x)); if (x == NULL) { + ret = TSK_ERR_NO_MEMORY; goto out; } - tsk_bug_assert(edge >= 0); x->edge = edge; x->extended = false; x->next = NULL; + + if (*tail == NULL) { + *head = x; + } else { + (*tail)->next = x; + } + *tail = x; out: - return x; + return ret; } static void @@ -12953,7 +12962,7 @@ 
forward_extend(tsk_table_collection_t *self, int direction) tsk_blkalloc_t edge_list_heap; edge_list_t *edges_in_head, *edges_in_tail; edge_list_t *edges_out_head, *edges_out_tail; - edge_list_t *x, *y, *ex1, *ex2, *ex_in; + edge_list_t *ex1, *ex2, *ex_in; tsk_edge_table_t edges; tsk_edge_t edge; double sign, here, there; @@ -13021,36 +13030,24 @@ forward_extend(tsk_table_collection_t *self, int direction) while (((tk < M) && (tk >= 0)) && (far_side[O[tk]] == here)) { // add edge tk to pending_out - x = extend_edges_alloc_entry(&edge_list_heap, O[tk]); - if (x == NULL) { + ret = extend_edges_append_entry( + &edges_out_head, &edges_out_tail, &edge_list_heap, O[tk]); + if (ret != 0) { ret = TSK_ERR_NO_MEMORY; goto out; } - if (edges_out_tail == NULL) { - edges_out_head = x; - } else { - y = edges_out_tail; - y->next = x; - } - edges_out_tail = x; num_children[edges.parent[O[tk]]] -= 1; num_children[edges.child[O[tk]]] -= 1; tk += sign_int; } while (((tj < M) && (tj >= 0)) && (near_side[I[tj]] == here)) { // add edge tj to pending_in - x = extend_edges_alloc_entry(&edge_list_heap, I[tj]); - if (x == NULL) { + ret = extend_edges_append_entry( + &edges_in_head, &edges_in_tail, &edge_list_heap, I[tj]); + if (ret != 0) { ret = TSK_ERR_NO_MEMORY; goto out; } - if (edges_in_tail == NULL) { - edges_in_head = x; - } else { - y = edges_in_tail; - y->next = x; - } - edges_in_tail = x; num_children[edges.parent[I[tj]]] += 1; num_children[edges.child[I[tj]]] += 1; tj += sign_int; From 0f0cf3dd93d20a70a1707221ad54017e96b8017f Mon Sep 17 00:00:00 2001 From: peter Date: Fri, 4 Aug 2023 08:34:27 -0700 Subject: [PATCH 76/84] fixup of tests --- python/tests/test_topology.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index 43b382c4bb..0a0988bf92 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -8654,12 +8654,12 @@ def verify_extend_edges(self, ts): 
for e in ets.edges(): # e should be in old_edges, - # but with expanded limits + # but with modified limits k = (e.parent, e.child) assert k in old_edges overlaps = False for (l, r) in old_edges[k]: - if (l >= e.left) and (r <= e.right): + if (l <= e.right) and (r >= e.left): overlaps = True assert overlaps @@ -8774,8 +8774,20 @@ def test_simple_ex(self): assert t.parent(8) == 6 self.verify_extend_edges(ts) - def test_extend_edges(self): - tables = wf.wf_sim(5, 20, deep_history=False, seed=3) + def test_wright_fisher_trees(self): + tables = wf.wf_sim(N=5, ngens=20, deep_history=False, seed=3) tables.sort() ts = tables.tree_sequence().simplify() self.verify_extend_edges(ts) + + def test_wright_fisher_trees_unsimplified(self): + tables = wf.wf_sim(N=6, ngens=22, deep_history=False, seed=4) + tables.sort() + ts = tables.tree_sequence() + self.verify_extend_edges(ts) + + def test_wright_fisher_trees_with_history(self): + tables = wf.wf_sim(N=8, ngens=15, deep_history=True, seed=5) + tables.sort() + ts = tables.tree_sequence() + self.verify_extend_edges(ts) From 8973c7b86ada7a19decd9f88031ea045dfd95e6a Mon Sep 17 00:00:00 2001 From: peter Date: Fri, 4 Aug 2023 09:27:07 -0700 Subject: [PATCH 77/84] docstrings --- c/tskit/tables.h | 16 +++++++++++++++- python/tskit/tables.py | 6 +++++- python/tskit/trees.py | 36 ++++++++++++++++++++++-------------- 3 files changed, 42 insertions(+), 16 deletions(-) diff --git a/c/tskit/tables.h b/c/tskit/tables.h index 4336bc1249..d815a322b7 100644 --- a/c/tskit/tables.h +++ b/c/tskit/tables.h @@ -4362,7 +4362,21 @@ int tsk_table_collection_simplify(tsk_table_collection_t *self, const tsk_id_t * /** @brief Extends edges -TODO DOCUMENT +Modifies the tables in place so that the span covered by ancestral nodes +is "extended" to regions of the genome according to the following rule: +If an ancestral segment corresponding to node `n` has parent `p` and +child `c` on some portion of the genome, and on an adjacent segment of +genome `p` is the 
immediate parent of `c`, then `n` is inserted into the +edge from `p` to `c`. This involves extending the span of the edges +from `p` to `n` and `n` to `c` and reducing the span of the edge from +`p` to `c`. Since the latter edge may be removed entirely, this process +reduces (or at least does not increase) the number of edges in the tree +sequence. + +The method works by iterating over the genome to look for edges that can +be extended in this way; the maximum number of such iterations is +controlled by ``max_iter``. + @rst diff --git a/python/tskit/tables.py b/python/tskit/tables.py index 9a95e3e2b8..10ae3a594a 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -4039,7 +4039,11 @@ def drop_index(self): def extend_edges(self, max_iter=10): """ - TODO DOCUMENT + Modifies the tables in place by applying the operation described + in :meth:`TreeSequence.extend_edges`. + + :param int max_iters: The maximum number of iterations over the tree + sequence. Defaults to 10. """ self._ll_tables.extend_edges(max_iter) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 74c010ff9b..36536f37c1 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -6904,21 +6904,29 @@ def decapitate(self, time, *, flags=None, population=None, metadata=None): def extend_edges(self, max_iter=10): """ - TODO: make this better Returns a new tree sequence in which the span covered by ancestral nodes - is "extended" to regions of the genome over which their ancestry is - unambiguous, which occurs if the node is an intermediate in a chain - of ancestry that also exists in neighboring regions of the genome. - While iterating over the tree sequence, in each tree, we identify - connecting edges with unary nodes. If an equivalent edge segment - exists in the next tree without that unary node, we extend the - connecting edges from the previous tree into the next tree, - subsequently adding that unary node to the tree. 
This in turn reduces - the length of the edge just removed from the next tree, and if its - length becomes zero it is removed from the edge table. - - :param int max_iters: The maximum number of forward-and-backward - iterations over the tree sequence. Defaults to 10. + is "extended" to regions of the genome according to the following rule: + If an ancestral segment corresponding to node `n` has parent `p` and + child `c` on some portion of the genome, and on an adjacent segment of + genome `p` is the immediate parent of `c`, then `n` is inserted into the + edge from `p` to `c`. This involves extending the span of the edges + from `p` to `n` and `n` to `c` and reducing the span of the edge from + `p` to `c`. Since the latter edge may be removed entirely, this process + reduces (or at least does not increase) the number of edges in the tree + sequence. + + The method works by iterating over the genome to look for edges that can + be extended in this way; the maximum number of such iterations is + controlled by ``max_iter``. + + The rationale is that we know that `n` carries a portion of the segment + of ancestral genome inherited by `c` from `p`, and so likely carries + the *entire* inherited segment (since the implication otherwise would + be that distinct recombined segments were passed down separately from + `p` to `c`). + + :param int max_iters: The maximum number of iterations over the tree + sequence. Defaults to 10. :return: A new tree sequence with unary nodes extended. 
:rtype: tskit.TreeSequence From 18f10605e54a4fa3e56c3aaf1b68682b2c3e5f9f Mon Sep 17 00:00:00 2001 From: peter Date: Mon, 14 Aug 2023 21:08:13 -0700 Subject: [PATCH 78/84] changelog, docs and warning in docstring --- docs/python-api.md | 1 + python/CHANGELOG.rst | 4 ++++ python/tskit/trees.py | 7 +++++++ 3 files changed, 12 insertions(+) diff --git a/docs/python-api.md b/docs/python-api.md index a8236daadf..20a2d541b4 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -268,6 +268,7 @@ which perform the same actions but modify the {class}`TableCollection` in place. TreeSequence.trim TreeSequence.split_edges TreeSequence.decapitate + TreeSequence.extend_edges ``` (sec_python_api_tree_sequences_ibd)= diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index f529d930d4..d2e19f2be6 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -15,6 +15,10 @@ - Add ``asdict`` to all dataclasses. These are returned when you access a row or other tree sequence object. (:user:`benjeffery`, :pr:`2759`, :issue:`2719`) +- Add ``TreeSequence.extend_edges`` method that extends ancestral haplotypes + using recombination information, leading to unary nodes in many trees and + fewer edges. (:user:`petrelharp`, :user:`hfr1tz3`, :user:`avabamf`, :pr:`2651`) + -------------------- [0.5.5] - 2023-05-17 -------------------- diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 36536f37c1..836cbafee5 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -6915,6 +6915,9 @@ def extend_edges(self, max_iter=10): reduces (or at least does not increase) the number of edges in the tree sequence. + *Note:* this is a somewhat experimental operation, and is probably not + what you are looking for. + The method works by iterating over the genome to look for edges that can be extended in this way; the maximum number of such iterations is controlled by ``max_iter``. 
@@ -6925,6 +6928,10 @@ def extend_edges(self, max_iter=10): be that distinct recombined segments were passed down separately from `p` to `c`). + The method will not affect the marginal trees (so, following up with + `simplify` will recover the original tree sequence, possibly with edges + in a different order). + :param int max_iters: The maximum number of iterations over the tree sequence. Defaults to 10. From e88ebbbdb6f51c878a1e150f6679913b26d911f5 Mon Sep 17 00:00:00 2001 From: peter Date: Wed, 16 Aug 2023 14:17:54 -0700 Subject: [PATCH 79/84] ts_position-style alg passes tests and agrees with C --- c/tskit/tables.c | 37 +++-- c/tskit/trees.c | 265 ++++++++++++++++++++++++++++++++ python/tests/test_topology.py | 278 ++++++++++++++++++++++------------ python/tskit/tables.py | 5 +- 4 files changed, 470 insertions(+), 115 deletions(-) diff --git a/c/tskit/tables.c b/c/tskit/tables.c index cc0345a977..a1b6b88bad 100644 --- a/c/tskit/tables.c +++ b/c/tskit/tables.c @@ -12885,6 +12885,10 @@ tsk_table_collection_add_and_remap_node(tsk_table_collection_t *self, return ret; } +/* ======================================================== * + * Extend edges + * ======================================================== */ + typedef struct _edge_list_t { tsk_id_t edge; // the `extended` flags records whether we have decided to extend @@ -12946,14 +12950,14 @@ remove_unextended(edge_list_t **head, edge_list_t **tail) } static int -forward_extend(tsk_table_collection_t *self, int direction) +do_extend(tsk_table_collection_t *self, int direction) { // Note: this modifies the edge table, but it does this by (a) removing // some edges, and (b) extending left/right endpoints of others, // while keeping order the same, and so this maintains sortedness // (so, there is no need to sort afterwards). 
int ret = 0; - tsk_id_t *num_children; + tsk_id_t *degree; tsk_id_t tj, tk, ret_id; tsk_id_t e1, e2, e_in; tsk_id_t *I, *O; @@ -12972,12 +12976,12 @@ forward_extend(tsk_table_collection_t *self, int direction) // need to do this so tsk_safe_free works on it if it's not initialized memset(&edges, 0, sizeof(edges)); - num_children = tsk_malloc(self->nodes.num_rows * sizeof(*num_children)); - if (num_children == NULL) { + degree = tsk_malloc(self->nodes.num_rows * sizeof(*degree)); + if (degree == NULL) { ret = TSK_ERR_NO_MEMORY; goto out; } - tsk_memset(num_children, 0x00, self->nodes.num_rows * sizeof(*num_children)); + tsk_memset(degree, 0x00, self->nodes.num_rows * sizeof(*degree)); ret = tsk_blkalloc_init(&edge_list_heap, 8192); if (ret != 0) { @@ -13036,8 +13040,6 @@ forward_extend(tsk_table_collection_t *self, int direction) ret = TSK_ERR_NO_MEMORY; goto out; } - num_children[edges.parent[O[tk]]] -= 1; - num_children[edges.child[O[tk]]] -= 1; tk += sign_int; } while (((tj < M) && (tj >= 0)) && (near_side[I[tj]] == here)) { @@ -13048,10 +13050,17 @@ forward_extend(tsk_table_collection_t *self, int direction) ret = TSK_ERR_NO_MEMORY; goto out; } - num_children[edges.parent[I[tj]]] += 1; - num_children[edges.child[I[tj]]] += 1; tj += sign_int; } + for (ex1 = edges_out_head; ex1 != NULL; ex1 = ex1->next) { + degree[edges.parent[ex1->edge]] -= 1; + degree[edges.child[ex1->edge]] -= 1; + } + for (ex1 = edges_in_head; ex1 != NULL; ex1 = ex1->next) { + degree[edges.parent[ex1->edge]] += 1; + degree[edges.child[ex1->edge]] += 1; + } + there = forwards ? 
self->sequence_length : 0; if (forwards) { if (tk < M) { @@ -13077,7 +13086,7 @@ forward_extend(tsk_table_collection_t *self, int direction) if (!ex2->extended) { e2 = ex2->edge; if ((edges.parent[e1] == edges.child[e2]) - && (num_children[edges.child[e2]] == 0)) { + && (degree[edges.child[e2]] == 0)) { for (ex_in = edges_in_head; ex_in != NULL; ex_in = ex_in->next) { e_in = ex_in->edge; @@ -13090,7 +13099,7 @@ forward_extend(tsk_table_collection_t *self, int direction) far_side[e1] = there; far_side[e2] = there; near_side[e_in] = there; - num_children[edges.parent[e1]] += 2; + degree[edges.parent[e1]] += 2; } } } @@ -13122,7 +13131,7 @@ forward_extend(tsk_table_collection_t *self, int direction) out: tsk_blkalloc_free(&edge_list_heap); - tsk_safe_free(num_children); + tsk_safe_free(degree); tsk_edge_table_free(&edges); return ret; } @@ -13135,11 +13144,11 @@ tsk_table_collection_extend_edges( tsk_size_t last_num_edges; last_num_edges = self->edges.num_rows; for (int j = 0; j <= max_iter; j++) { - ret = forward_extend(self, TSK_DIR_FORWARD); + ret = do_extend(self, TSK_DIR_FORWARD); if (ret != 0) { goto out; } - ret = forward_extend(self, TSK_DIR_REVERSE); + ret = do_extend(self, TSK_DIR_REVERSE); if (ret != 0) { goto out; } diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 8a3d0afc95..4ac90cca74 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -6803,3 +6803,268 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples, out: return ret; } + +/* ======================================================== * + * Extend edges + * ======================================================== */ + +typedef struct _edge_list_t { + tsk_id_t edge; + // the `extended` flags records whether we have decided to extend + // this entry to the current tree? 
+ bool extended; + struct _edge_list_t *next; +} edge_list_t; + +static int +extend_edges_append_entry( + edge_list_t **head, edge_list_t **tail, tsk_blkalloc_t *heap, tsk_id_t edge) +{ + int ret = 0; + edge_list_t *x = NULL; + + x = tsk_blkalloc_get(heap, sizeof(*x)); + if (x == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + + x->edge = edge; + x->extended = false; + x->next = NULL; + + if (*tail == NULL) { + *head = x; + } else { + (*tail)->next = x; + } + *tail = x; +out: + return ret; +} + +static void +remove_unextended(edge_list_t **head, edge_list_t **tail) +{ + edge_list_t *px, *x; + + px = *head; + while (px != NULL && !px->extended) { + px = px->next; + } + *head = px; + if (px != NULL) { + px->extended = false; + x = px->next; + while (x != NULL) { + if (x->extended) { + x->extended = false; + px->next = x; + px = x; + } + x = x->next; + } + } + *tail = px; +} + +static int +do_extend(tsk_treeseq_t *self, int direction) +{ + // Note: this modifies the edge table, but it does this by (a) removing + // some edges, and (b) extending left/right endpoints of others, + // while keeping order the same, and so this maintains sortedness + // (so, there is no need to sort afterwards). 
+ int ret = 0; + tsk_id_t *degree; + tsk_id_t tj, ret_id; + tsk_id_t e, e1, e2, e_in; + double *near_side, *far_side; + tsk_blkalloc_t edge_list_heap; + edge_list_t *edges_in_head, *edges_in_tail; + edge_list_t *edges_out_head, *edges_out_tail; + edge_list_t *ex1, *ex2, *ex_in; + tsk_edge_table_t edges; + tsk_edge_t edge; + double there, left, right; + tsk_id_t sign_int; + bool forwards = (direction == TSK_DIR_FORWARD); + tsk_tree_position_t tree_pos; + bool valid; + + // need to do this so tsk_safe_free works on it if it's not initialized + memset(&edges, 0, sizeof(edges)); + + degree = tsk_malloc(self->tables->nodes.num_rows * sizeof(*degree)); + if (degree == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + tsk_memset(degree, 0x00, self->tables->nodes.num_rows * sizeof(*degree)); + + ret = tsk_blkalloc_init(&edge_list_heap, 8192); + if (ret != 0) { + goto out; + } + + ret = tsk_tree_position_init(&tree_pos, self, 0); + if (ret != 0) { + goto out; + } + + ret = tsk_edge_table_copy(&self->tables->edges, &edges, 0); + if (ret != 0) { + goto out; + } + ret = tsk_edge_table_clear(&self->tables->edges); + if (ret != 0) { + goto out; + } + + if (forwards) { + sign_int = 1; + near_side = edges.left; + far_side = edges.right; + } else { + sign_int = -1; + near_side = edges.right; + far_side = edges.left; + } + edges_in_head = NULL; + edges_in_tail = NULL; + edges_out_head = NULL; + edges_out_tail = NULL; + + if (forwards) { + valid = tsk_tree_position_next(&tree_pos); + } else { + valid = tsk_tree_position_prev(&tree_pos); + } + + while (valid) { + left = tree_pos.interval.left; + right = tree_pos.interval.right; + there = forwards ? 
right : left; + + // remove entries that aren't being extended/postponed + remove_unextended(&edges_in_head, &edges_in_tail); + remove_unextended(&edges_out_head, &edges_out_tail); + + for (tj = tree_pos.out.start; tj != tree_pos.out.stop; tj += sign_int) { + e = tree_pos.out.order[tj]; + // add edge to pending_out + ret = extend_edges_append_entry( + &edges_out_head, &edges_out_tail, &edge_list_heap, e); + if (ret != 0) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + } + for (tj = tree_pos.in.start; tj != tree_pos.in.stop; tj += sign_int) { + e = tree_pos.in.order[tj]; + // add edge to pending_in + ret = extend_edges_append_entry( + &edges_in_head, &edges_in_tail, &edge_list_heap, e); + if (ret != 0) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + } + for (ex1 = edges_out_head; ex1 != NULL; ex1 = ex1->next) { + degree[edges.parent[ex1->edge]] -= 1; + degree[edges.child[ex1->edge]] -= 1; + } + for (ex1 = edges_in_head; ex1 != NULL; ex1 = ex1->next) { + degree[edges.parent[ex1->edge]] += 1; + degree[edges.child[ex1->edge]] += 1; + } + + // iterate over pairs of out and in: (ex1, ex2, in) + for (ex1 = edges_out_head; ex1 != NULL; ex1 = ex1->next) { + if (!ex1->extended) { + e1 = ex1->edge; + for (ex2 = edges_out_head; ex2 != NULL; ex2 = ex2->next) { + if (!ex2->extended) { + e2 = ex2->edge; + if ((edges.parent[e1] == edges.child[e2]) + && (degree[edges.child[e2]] == 0)) { + for (ex_in = edges_in_head; ex_in != NULL; + ex_in = ex_in->next) { + e_in = ex_in->edge; + if ((edges.left[e_in] < right) + && (edges.right[e_in] > left)) { + if ((edges.child[e1] == edges.child[e_in]) + && (edges.parent[e2] == edges.parent[e_in])) { + ex1->extended = true; + ex2->extended = true; + ex_in->extended = true; + far_side[e1] = there; + far_side[e2] = there; + near_side[e_in] = there; + degree[edges.parent[e1]] += 2; + } + } + } + } + } + } + } + } + if (forwards) { + valid = tsk_tree_position_next(&tree_pos); + } else { + valid = tsk_tree_position_prev(&tree_pos); + } + } + + // done! 
write out new edge tables + for (tj = 0; tj < (tsk_id_t) edges.num_rows; tj++) { + tsk_edge_table_get_row(&edges, tj, &edge); + if (edge.left < edge.right) { + ret_id = tsk_edge_table_add_row(&self->tables->edges, edge.left, edge.right, + edge.parent, edge.child, edge.metadata, edge.metadata_length); + if (ret_id < 0) { + ret = (int) ret_id; + goto out; + } + } + } + ret = tsk_table_collection_build_index(self->tables, 0); + if (ret != 0) { + goto out; + } + +out: + tsk_blkalloc_free(&edge_list_heap); + tsk_safe_free(degree); + tsk_edge_table_free(&edges); + tsk_tree_position_free(&tree_pos); + return ret; +} + +int TSK_WARN_UNUSED +tsk_tree_sequence_extend_edges( + tsk_treeseq_t *self, int max_iter, tsk_flags_t TSK_UNUSED(options)) +{ + int ret = 0; + tsk_size_t last_num_edges; + last_num_edges = self->tables->edges.num_rows; + for (int j = 0; j <= max_iter; j++) { + ret = do_extend(self, TSK_DIR_FORWARD); + if (ret != 0) { + goto out; + } + ret = do_extend(self, TSK_DIR_REVERSE); + if (ret != 0) { + goto out; + } + if (self->tables->edges.num_rows == last_num_edges) { + break; + } else { + last_num_edges = self->tables->edges.num_rows; + } + } + +out: + return ret; +} diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index 0a0988bf92..da349eb3e6 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -8534,65 +8534,60 @@ def py_extend_edges(self, ts, max_iter=10): return ts def _extend(self, ts, forwards=True): - num_children = np.full(ts.num_nodes, 0) + degree = np.full(ts.num_nodes, 0) - t = ts.tables + t = ts.dump_tables() edges = t.edges.copy() t.edges.clear() - new_left = edges.left - new_right = edges.right - # edge diff stuff - M = edges.num_rows + # "here" will be left if fowards else right; + # and "there" is the other + new_left = edges.left.copy() + new_right = edges.right.copy() if forwards: - I = ts.indexes_edge_insertion_order # NOQA E741 - O = ts.indexes_edge_removal_order # NOQA E741 - # "here" 
will be left if fowards else right - here = 0 - sign = +1 - near_edge = edges.left - far_edge = edges.right + direction = 1 + # in C we can just modify these in place, but in + # python they are (silently) immutable + new_here = new_left + new_there = new_right else: - I = np.flip(ts.indexes_edge_removal_order) # NOQA E741 - O = np.flip(ts.indexes_edge_insertion_order) # NOQA E741 - here = ts.sequence_length - sign = -1 - near_edge = edges.right - far_edge = edges.left - tj = 0 - tk = 0 + direction = -1 + new_here = new_right + new_there = new_left edges_out = [] edges_in = [] - while tj < M: + tree_pos = tsutil.TreePosition(ts) + if forwards: + valid = tree_pos.next() + else: + valid = tree_pos.prev() + while valid: + left, right = tree_pos.interval + there = right if forwards else left + # clear out non-extended or postponed edges edges_out = [[e, False] for e, x in edges_out if x] edges_in = [[e, False] for e, x in edges_in if x] - while (tk < M) and (far_edge[O[tk]] == here): - edges_out.append([O[tk], False]) - num_children[edges.parent[O[tk]]] -= 1 - num_children[edges.child[O[tk]]] -= 1 - tk += 1 + for j in range( + tree_pos.out_range.start, tree_pos.out_range.stop, direction + ): + e = tree_pos.out_range.order[j] + edges_out.append([e, False]) - while (tj < M) and (near_edge[I[tj]] == here): - edges_in.append([I[tj], False]) - num_children[edges.parent[I[tj]]] += 1 - num_children[edges.child[I[tj]]] += 1 - tj += 1 + for j in range(tree_pos.in_range.start, tree_pos.in_range.stop, direction): + e = tree_pos.in_range.order[j] + edges_in.append([e, False]) - there = ts.sequence_length if forwards else 0 - if forwards: - if tk < M: - there = min(there, far_edge[O[tk]]) - if tj < M: - there = min(there, near_edge[I[tj]]) - else: - if tk < M: - there = max(there, far_edge[O[tk]]) - if tj < M: - there = max(there, near_edge[I[tj]]) - assert np.all(num_children >= 0) + for e, _ in edges_out: + degree[edges.parent[e]] -= 1 + degree[edges.child[e]] -= 1 + for e, _ in 
edges_in: + degree[edges.parent[e]] += 1 + degree[edges.child[e]] += 1 + + assert np.all(degree >= 0) for ex1 in edges_out: if not ex1[1]: e1 = ex1[0] @@ -8602,11 +8597,17 @@ def _extend(self, ts, forwards=True): # the new tree e2 = ex2[0] if (edges.parent[e1] == edges.child[e2]) and ( - num_children[edges.child[e2]] == 0 + degree[edges.child[e2]] == 0 ): for ex_in in edges_in: e_in = ex_in[0] - if sign * far_edge[e_in] > sign * here: + # we might have passed the interval that a + # postponed edge in covers, in which case + # we should skip it + if ( + new_left[e_in] < right + and new_right[e_in] > left + ): if ( edges.child[e1] == edges.child[e_in] and edges.parent[e2] == edges.parent[e_in] @@ -8614,19 +8615,17 @@ def _extend(self, ts, forwards=True): ex1[1] = True ex2[1] = True ex_in[1] = True - if forwards: - new_right[e1] = there - new_right[e2] = there - new_left[e_in] = there - else: - new_left[e1] = there - new_left[e2] = there - new_right[e_in] = there - # amend num_children: the intermediate + new_there[e1] = there + new_there[e2] = there + new_here[e_in] = there + # amend degree: the intermediate # node has 2 edges instead of 0 - num_children[edges.parent[e1]] += 2 - # cleanup at end of loop - here = there + degree[edges.parent[e1]] += 2 + # end of loop, next tree + if forwards: + valid = tree_pos.next() + else: + valid = tree_pos.prev() for j in range(edges.num_rows): left = new_left[j] @@ -8637,8 +8636,14 @@ def _extend(self, ts, forwards=True): t.build_index() return t.tree_sequence() - def verify_extend_edges(self, ts): - ets = ts.extend_edges() + def verify_extend_edges(self, ts, max_iter=10): + # This can still fail for various weird examples: + # for instance, if adjacent trees have + # a <- b <- c <- d and a <- d (where say b was + # inserted in an earlier pass), then b and c + # won't be extended + + ets = ts.extend_edges(max_iter=max_iter) assert ts.num_samples == ets.num_samples assert ts.num_nodes == ets.num_nodes assert ts.num_edges >= 
ets.num_edges @@ -8654,59 +8659,119 @@ def verify_extend_edges(self, ts): for e in ets.edges(): # e should be in old_edges, - # but with modified limits + # but with modified limits: + # USUALLY overlapping limits, but + # not necessarily after more than one pass k = (e.parent, e.child) assert k in old_edges - overlaps = False - for (l, r) in old_edges[k]: - if (l <= e.right) and (r >= e.left): - overlaps = True - assert overlaps - - chains = [] - for _, tt, _ett in ts.coiterate(ets): - this_chains = [] - for a in tt.nodes(): - b = tt.parent(a) - if b != tskit.NULL: - c = tt.parent(b) - if c != tskit.NULL: - this_chains.append((a, b, c)) - chains.append(this_chains) - - for k, (_, tt, ett) in enumerate(ts.coiterate(ets)): - for j in (k - 1, k + 1): - if j < 0 or j >= len(chains): - next - else: - this_chains = chains[j] - for a, b, c in this_chains: - if a in tt.nodes() and tt.parent(a) == c and b not in tt.nodes(): - # the relationship a <- b <- c should still be in the tree, - # although maybe they aren't direct parent-offspring - assert a in ett.nodes() + if max_iter == 1: + overlaps = False + for (l, r) in old_edges[k]: + if (l <= e.right) and (r >= e.left): + overlaps = True + assert overlaps + + if max_iter > 1: + chains = [] + for _, tt, ett in ts.coiterate(ets): + this_chains = [] + for a in tt.nodes(): + assert a in ett.nodes() + b = tt.parent(a) + if b != tskit.NULL: + c = tt.parent(b) + if c != tskit.NULL: + this_chains.append((a, b, c)) assert b in ett.nodes() - assert c in ett.nodes() + # the relationship a <- b should still be in the tree p = a - while p != tskit.NULL: - if p == b: - break + while p != tskit.NULL and p != b: p = ett.parent(p) assert p == b - while p != tskit.NULL: - if p == c: - break - p = ett.parent(p) - assert p == c + chains.append(this_chains) + + extended_ac = {} + not_extended_ac = {} + extended_ab = {} + not_extended_ab = {} + for k, (interval, tt, ett) in enumerate(ts.coiterate(ets)): + for j in (k - 1, k + 1): + if j < 0 or 
j >= len(chains): + continue + else: + this_chains = chains[j] + for a, b, c in this_chains: + if ( + a in tt.nodes() + and tt.parent(a) == c + and b not in tt.nodes() + ): + # the relationship a <- b <- c should still be in the tree, + # although maybe they aren't direct parent-offspring + # UNLESS we've got an ambiguous case, where on the opposite + # side of the interval a chain a <- b' <- c got extended + # into the region OR b got inserted into another chain + assert a in ett.nodes() + assert c in ett.nodes() + if b not in ett.nodes(): + if (a, c) not in not_extended_ac: + not_extended_ac[(a, c)] = [] + not_extended_ac[(a, c)].append(interval) + else: + if (a, c) not in extended_ac: + extended_ac[(a, c)] = [] + extended_ac[(a, c)].append(interval) + p = a + while p != tskit.NULL and p != b: + p = ett.parent(p) + if p != b: + if (a, b) not in not_extended_ab: + not_extended_ab[(a, b)] = [] + not_extended_ab[(a, b)].append(interval) + else: + if (a, b) not in extended_ab: + extended_ab[(a, b)] = [] + extended_ab[(a, b)].append(interval) + while p != tskit.NULL and p != c: + p = ett.parent(p) + assert p == c + for a, c in not_extended_ac: + # check that a <- ... 
<- c has been extended somewhere + # although not necessarily from an adjacent segment + assert (a, c) in extended_ac + for interval in not_extended_ac[(a, c)]: + ett = ets.at(interval.left) + assert ett.parent(a) != c + for k in not_extended_ab: + assert k in extended_ab + for interval in not_extended_ab[k]: + assert interval in extended_ab[k] # finally, compare C version to python version - py_et = self.py_extend_edges(ts).dump_tables() + py_et = self.py_extend_edges(ts, max_iter=max_iter).dump_tables() et = ets.dump_tables() + # py_et.tree_sequence().dump("py.trees") + # ets.dump("c.trees") et.assert_equals(py_et) + def test_extend_edges_errors(self): + t = msprime.simulate(5, random_seed=126).dump_tables() + t.drop_index() + with pytest.raises(_tskit.LibraryError, match="TSK_ERR_TABLES_NOT_INDEXED"): + t.extend_edges() + def test_runs(self): ts = msprime.simulate(5, random_seed=126) self.verify_extend_edges(ts) + def test_max_iter(self): + ts = msprime.simulate(5, random_seed=126) + with pytest.raises(ValueError, match="max_iter"): + ets = ts.extend_edges(max_iter=0) + ets = ts.extend_edges(max_iter=1) + et = ets.extend_edges(max_iter=1).dump_tables() + eet = ets.extend_edges(max_iter=2).dump_tables() + eet.assert_equals(et) + def test_simple_ex(self): # An example where you need to go forwards *and* backwards: # 7 and 8 should be extended to the whole sequence @@ -8777,17 +8842,30 @@ def test_simple_ex(self): def test_wright_fisher_trees(self): tables = wf.wf_sim(N=5, ngens=20, deep_history=False, seed=3) tables.sort() - ts = tables.tree_sequence().simplify() + tables.simplify() + ts = tables.tree_sequence() + # self.verify_extend_edges(ts, max_iter=1) self.verify_extend_edges(ts) def test_wright_fisher_trees_unsimplified(self): tables = wf.wf_sim(N=6, ngens=22, deep_history=False, seed=4) tables.sort() ts = tables.tree_sequence() + # self.verify_extend_edges(ts, max_iter=1) self.verify_extend_edges(ts) def test_wright_fisher_trees_with_history(self): tables = 
wf.wf_sim(N=8, ngens=15, deep_history=True, seed=5) tables.sort() + tables.simplify() ts = tables.tree_sequence() + # self.verify_extend_edges(ts, max_iter=1) self.verify_extend_edges(ts) + + # def test_bigger_wright_fisher(self): + # tables = wf.wf_sim(N=50, ngens=15, deep_history=True, seed=6) + # tables.sort() + # tables.simplify() + # ts = tables.tree_sequence() + # self.verify_extend_edges(ts, max_iter=1) + # self.verify_extend_edges(ts, max_iter=200) diff --git a/python/tskit/tables.py b/python/tskit/tables.py index 10ae3a594a..a0cf6d16a3 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -4042,9 +4042,12 @@ def extend_edges(self, max_iter=10): Modifies the tables in place by applying the operation described in :meth:`TreeSequence.extend_edges`. - :param int max_iters: The maximum number of iterations over the tree + :param int max_iter: The maximum number of iterations over the tree sequence. Defaults to 10. """ + max_iter = int(max_iter) + if max_iter <= 0: + raise ValueError("max_iter must be a positive integer.") self._ll_tables.extend_edges(max_iter) def subset( From dbd70330db4c42cb2825b19ac48efa2ddf5c46c4 Mon Sep 17 00:00:00 2001 From: peter Date: Thu, 17 Aug 2023 08:58:48 -0700 Subject: [PATCH 80/84] working!!! 
--- c/tests/test_tables.c | 103 ------------- c/tests/test_trees.c | 106 +++++++++++++ c/tskit/tables.c | 278 ---------------------------------- c/tskit/tables.h | 32 ---- c/tskit/trees.c | 65 +++++--- c/tskit/trees.h | 32 ++++ python/_tskitmodule.c | 74 +++++---- python/tests/test_topology.py | 19 +-- python/tskit/tables.py | 13 -- python/tskit/trees.py | 8 +- 10 files changed, 235 insertions(+), 495 deletions(-) diff --git a/c/tests/test_tables.c b/c/tests/test_tables.c index 58386894ab..6de6675ff6 100644 --- a/c/tests/test_tables.c +++ b/c/tests/test_tables.c @@ -2507,107 +2507,6 @@ test_edge_table_copy_semantics(void) tsk_treeseq_free(&ts); } -static void -test_extend_edges_simple(void) -{ - int ret; - tsk_table_collection_t tables, tables_copy; - - const char *nodes_ex = "1 0 -1 -1\n" - "1 0 -1 -1\n" - "0 2.0 -1 -1\n"; - const char *edges_ex = "0 10 2 0\n" - "0 10 2 1\n"; - - ret = tsk_table_collection_init(&tables, 0); - CU_ASSERT_EQUAL_FATAL(ret, 0); - tables.sequence_length = 10; - - parse_nodes(nodes_ex, &tables.nodes); - CU_ASSERT_EQUAL_FATAL(tables.nodes.num_rows, 3); - parse_edges(edges_ex, &tables.edges); - CU_ASSERT_EQUAL_FATAL(tables.edges.num_rows, 2); - - tsk_table_collection_drop_index(&tables, 0); - ret = tsk_table_collection_extend_edges(&tables, 10, 0); - CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_TABLES_NOT_INDEXED); - - tsk_table_collection_build_index(&tables, 0); - tsk_table_collection_copy(&tables, &tables_copy, 0); - ret = tsk_table_collection_extend_edges(&tables, 10, 0); - CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_TRUE(tsk_table_collection_equals(&tables, &tables_copy, 0)); - - tsk_table_collection_free(&tables); - tsk_table_collection_free(&tables_copy); -} - -static void -test_extend_edges(void) -{ - int ret; - tsk_table_collection_t tables; - /* 7 and 8 should be extended to the whole sequence - - 6 6 6 6 - +-+-+ +-+-+ +-+-+ +-+-+ - | | 7 | | 8 | | - | | ++-+ | | +-++ | | - 4 5 4 | | 4 | 5 4 5 - +++ +++ +++ | | | | +++ +++ +++ - 0 1 2 3 0 1 
2 3 0 1 2 3 0 1 2 3 - */ - - const char *nodes_ex = "1 0 -1 -1\n" - "1 0 -1 -1\n" - "1 0 -1 -1\n" - "1 0 -1 -1\n" - "0 1.0 -1 -1\n" - "0 1.0 -1 -1\n" - "0 3.0 -1 -1\n" - "0 2.0 -1 -1\n" - "0 2.0 -1 -1\n"; - // l, r, p, c - const char *edges_ex = "0 10 4 0\n" - "0 5 4 1\n" - "7 10 4 1\n" - "0 2 5 2\n" - "5 10 5 2\n" - "0 2 5 3\n" - "5 10 5 3\n" - "0 2 6 4\n" - "5 10 6 4\n" - "0 2 6 5\n" - "7 10 6 5\n" - "2 5 6 3\n" - "2 5 6 7\n" - "5 7 6 8\n" - "2 5 7 2\n" - "2 5 7 4\n" - "5 7 8 1\n" - "5 7 8 5\n"; - - ret = tsk_table_collection_init(&tables, 0); - CU_ASSERT_EQUAL_FATAL(ret, 0); - tables.sequence_length = 10; - - parse_nodes(nodes_ex, &tables.nodes); - CU_ASSERT_EQUAL_FATAL(tables.nodes.num_rows, 9); - parse_edges(edges_ex, &tables.edges); - CU_ASSERT_EQUAL_FATAL(tables.edges.num_rows, 18); - ret = tsk_table_collection_sort(&tables, NULL, 0); - CU_ASSERT_EQUAL_FATAL(ret, 0); - ret = tsk_table_collection_build_index(&tables, 0); - CU_ASSERT_EQUAL_FATAL(ret, 0); - - ret = tsk_table_collection_extend_edges(&tables, 10, 0); - CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL_FATAL(tables.nodes.num_rows, 9); - CU_ASSERT_EQUAL_FATAL(tables.edges.num_rows, 13); - - tsk_table_collection_free(&tables); -} - static void test_edge_table_squash(void) { @@ -11707,8 +11606,6 @@ main(int argc, char **argv) { "test_simplify_tables_drops_indexes", test_simplify_tables_drops_indexes }, { "test_simplify_empty_tables", test_simplify_empty_tables }, { "test_simplify_metadata", test_simplify_metadata }, - { "test_extend_edges_simple", test_extend_edges_simple }, - { "test_extend_edges", test_extend_edges }, { "test_link_ancestors_no_edges", test_link_ancestors_no_edges }, { "test_link_ancestors_input_errors", test_link_ancestors_input_errors }, { "test_link_ancestors_single_tree", test_link_ancestors_single_tree }, diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c index 63b7292322..a41efa6884 100644 --- a/c/tests/test_trees.c +++ b/c/tests/test_trees.c @@ -8212,6 +8212,110 @@ 
test_split_edges_errors(void) tsk_treeseq_free(&ts); } +static void +test_extend_edges_simple(void) +{ + int ret; + tsk_treeseq_t ts, ets; + tsk_table_collection_t tables; + + const char *nodes_ex = "1 0 -1 -1\n" + "1 0 -1 -1\n" + "0 2.0 -1 -1\n"; + const char *edges_ex = "0 10 2 0\n" + "0 10 2 1\n"; + + ret = tsk_table_collection_init(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tables.sequence_length = 10; + + parse_nodes(nodes_ex, &tables.nodes); + CU_ASSERT_EQUAL_FATAL(tables.nodes.num_rows, 3); + parse_edges(edges_ex, &tables.edges); + CU_ASSERT_EQUAL_FATAL(tables.edges.num_rows, 2); + + ret = tsk_treeseq_init(&ts, &tables, TSK_TS_INIT_BUILD_INDEXES); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_treeseq_extend_edges(&ts, 10, 0, &ets); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_table_collection_equals(ts.tables, ets.tables, 0)); + + tsk_table_collection_free(&tables); + tsk_treeseq_free(&ts); + tsk_treeseq_free(&ets); +} + +static void +test_extend_edges(void) +{ + int ret; + tsk_treeseq_t ts, ets; + tsk_table_collection_t tables; + /* 7 and 8 should be extended to the whole sequence + + 6 6 6 6 + +-+-+ +-+-+ +-+-+ +-+-+ + | | 7 | | 8 | | + | | ++-+ | | +-++ | | + 4 5 4 | | 4 | 5 4 5 + +++ +++ +++ | | | | +++ +++ +++ + 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 + */ + + const char *nodes_ex = "1 0 -1 -1\n" + "1 0 -1 -1\n" + "1 0 -1 -1\n" + "1 0 -1 -1\n" + "0 1.0 -1 -1\n" + "0 1.0 -1 -1\n" + "0 3.0 -1 -1\n" + "0 2.0 -1 -1\n" + "0 2.0 -1 -1\n"; + // l, r, p, c + const char *edges_ex = "0 10 4 0\n" + "0 5 4 1\n" + "7 10 4 1\n" + "0 2 5 2\n" + "5 10 5 2\n" + "0 2 5 3\n" + "5 10 5 3\n" + "0 2 6 4\n" + "5 10 6 4\n" + "0 2 6 5\n" + "7 10 6 5\n" + "2 5 6 3\n" + "2 5 6 7\n" + "5 7 6 8\n" + "2 5 7 2\n" + "2 5 7 4\n" + "5 7 8 1\n" + "5 7 8 5\n"; + + ret = tsk_table_collection_init(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tables.sequence_length = 10; + + parse_nodes(nodes_ex, &tables.nodes); + CU_ASSERT_EQUAL_FATAL(tables.nodes.num_rows, 9); + 
parse_edges(edges_ex, &tables.edges); + CU_ASSERT_EQUAL_FATAL(tables.edges.num_rows, 18); + ret = tsk_table_collection_sort(&tables, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_treeseq_init(&ts, &tables, TSK_TS_INIT_BUILD_INDEXES); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_treeseq_extend_edges(&ts, 10, 0, &ets); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(ets.tables->nodes.num_rows, 9); + CU_ASSERT_EQUAL_FATAL(ets.tables->edges.num_rows, 13); + + tsk_table_collection_free(&tables); + tsk_treeseq_free(&ts); + tsk_treeseq_free(&ets); +} + static void test_init_take_ownership_no_edge_metadata(void) { @@ -8431,6 +8535,8 @@ main(int argc, char **argv) { "test_split_edges_no_populations", test_split_edges_no_populations }, { "test_split_edges_populations", test_split_edges_populations }, { "test_split_edges_errors", test_split_edges_errors }, + { "test_extend_edges_simple", test_extend_edges_simple }, + { "test_extend_edges", test_extend_edges }, { "test_init_take_ownership_no_edge_metadata", test_init_take_ownership_no_edge_metadata }, { NULL, NULL }, diff --git a/c/tskit/tables.c b/c/tskit/tables.c index a1b6b88bad..8eea85f5ad 100644 --- a/c/tskit/tables.c +++ b/c/tskit/tables.c @@ -12885,284 +12885,6 @@ tsk_table_collection_add_and_remap_node(tsk_table_collection_t *self, return ret; } -/* ======================================================== * - * Extend edges - * ======================================================== */ - -typedef struct _edge_list_t { - tsk_id_t edge; - // the `extended` flags records whether we have decided to extend - // this entry to the current tree? 
- bool extended; - struct _edge_list_t *next; -} edge_list_t; - -static int -extend_edges_append_entry( - edge_list_t **head, edge_list_t **tail, tsk_blkalloc_t *heap, tsk_id_t edge) -{ - int ret = 0; - edge_list_t *x = NULL; - - x = tsk_blkalloc_get(heap, sizeof(*x)); - if (x == NULL) { - ret = TSK_ERR_NO_MEMORY; - goto out; - } - - x->edge = edge; - x->extended = false; - x->next = NULL; - - if (*tail == NULL) { - *head = x; - } else { - (*tail)->next = x; - } - *tail = x; -out: - return ret; -} - -static void -remove_unextended(edge_list_t **head, edge_list_t **tail) -{ - edge_list_t *px, *x; - - px = *head; - while (px != NULL && !px->extended) { - px = px->next; - } - *head = px; - if (px != NULL) { - px->extended = false; - x = px->next; - while (x != NULL) { - if (x->extended) { - x->extended = false; - px->next = x; - px = x; - } - x = x->next; - } - } - *tail = px; -} - -static int -do_extend(tsk_table_collection_t *self, int direction) -{ - // Note: this modifies the edge table, but it does this by (a) removing - // some edges, and (b) extending left/right endpoints of others, - // while keeping order the same, and so this maintains sortedness - // (so, there is no need to sort afterwards). 
- int ret = 0; - tsk_id_t *degree; - tsk_id_t tj, tk, ret_id; - tsk_id_t e1, e2, e_in; - tsk_id_t *I, *O; - const tsk_id_t M = (tsk_id_t) self->edges.num_rows; - double *near_side, *far_side; - tsk_blkalloc_t edge_list_heap; - edge_list_t *edges_in_head, *edges_in_tail; - edge_list_t *edges_out_head, *edges_out_tail; - edge_list_t *ex1, *ex2, *ex_in; - tsk_edge_table_t edges; - tsk_edge_t edge; - double sign, here, there; - tsk_id_t sign_int; - bool forwards = (direction == TSK_DIR_FORWARD); - - // need to do this so tsk_safe_free works on it if it's not initialized - memset(&edges, 0, sizeof(edges)); - - degree = tsk_malloc(self->nodes.num_rows * sizeof(*degree)); - if (degree == NULL) { - ret = TSK_ERR_NO_MEMORY; - goto out; - } - tsk_memset(degree, 0x00, self->nodes.num_rows * sizeof(*degree)); - - ret = tsk_blkalloc_init(&edge_list_heap, 8192); - if (ret != 0) { - goto out; - } - - if (!tsk_table_collection_has_index(self, 0)) { - ret = TSK_ERR_TABLES_NOT_INDEXED; - goto out; - } - - ret = tsk_edge_table_copy(&self->edges, &edges, 0); - if (ret != 0) { - goto out; - } - ret = tsk_edge_table_clear(&self->edges); - if (ret != 0) { - goto out; - } - - if (forwards) { - I = self->indexes.edge_insertion_order; - O = self->indexes.edge_removal_order; - here = 0; - sign = 1; - sign_int = 1; - near_side = edges.left; - far_side = edges.right; - tj = 0; - tk = 0; - } else { - O = self->indexes.edge_insertion_order; - I = self->indexes.edge_removal_order; - here = self->sequence_length; - sign = -1; - sign_int = -1; - near_side = edges.right; - far_side = edges.left; - tj = M - 1; - tk = M - 1; - } - edges_in_head = NULL; - edges_in_tail = NULL; - edges_out_head = NULL; - edges_out_tail = NULL; - while ((tj < M) && (tj >= 0)) { - // remove entries that aren't being extended/postponed - remove_unextended(&edges_in_head, &edges_in_tail); - remove_unextended(&edges_out_head, &edges_out_tail); - - while (((tk < M) && (tk >= 0)) && (far_side[O[tk]] == here)) { - // add edge 
tk to pending_out - ret = extend_edges_append_entry( - &edges_out_head, &edges_out_tail, &edge_list_heap, O[tk]); - if (ret != 0) { - ret = TSK_ERR_NO_MEMORY; - goto out; - } - tk += sign_int; - } - while (((tj < M) && (tj >= 0)) && (near_side[I[tj]] == here)) { - // add edge tj to pending_in - ret = extend_edges_append_entry( - &edges_in_head, &edges_in_tail, &edge_list_heap, I[tj]); - if (ret != 0) { - ret = TSK_ERR_NO_MEMORY; - goto out; - } - tj += sign_int; - } - for (ex1 = edges_out_head; ex1 != NULL; ex1 = ex1->next) { - degree[edges.parent[ex1->edge]] -= 1; - degree[edges.child[ex1->edge]] -= 1; - } - for (ex1 = edges_in_head; ex1 != NULL; ex1 = ex1->next) { - degree[edges.parent[ex1->edge]] += 1; - degree[edges.child[ex1->edge]] += 1; - } - - there = forwards ? self->sequence_length : 0; - if (forwards) { - if (tk < M) { - there = TSK_MIN(there, far_side[O[tk]]); - } - if (tj < M) { - there = TSK_MIN(there, near_side[I[tj]]); - } - } else { - if (tk >= 0) { - there = TSK_MAX(there, far_side[O[tk]]); - } - if (tj >= 0) { - there = TSK_MAX(there, near_side[I[tj]]); - } - } - - // iterate over pairs of out and in: (ex1, ex2, in) - for (ex1 = edges_out_head; ex1 != NULL; ex1 = ex1->next) { - if (!ex1->extended) { - e1 = ex1->edge; - for (ex2 = edges_out_head; ex2 != NULL; ex2 = ex2->next) { - if (!ex2->extended) { - e2 = ex2->edge; - if ((edges.parent[e1] == edges.child[e2]) - && (degree[edges.child[e2]] == 0)) { - for (ex_in = edges_in_head; ex_in != NULL; - ex_in = ex_in->next) { - e_in = ex_in->edge; - if (sign * far_side[e_in] > sign * here) { - if ((edges.child[e1] == edges.child[e_in]) - && (edges.parent[e2] == edges.parent[e_in])) { - ex1->extended = true; - ex2->extended = true; - ex_in->extended = true; - far_side[e1] = there; - far_side[e2] = there; - near_side[e_in] = there; - degree[edges.parent[e1]] += 2; - } - } - } - } - } - } - } - } - // cleanup at end of loop - here = there; - } - - // done! 
write out new edge tables - for (tj = 0; tj < (tsk_id_t) edges.num_rows; tj++) { - tsk_edge_table_get_row_unsafe(&edges, tj, &edge); - if (edge.left < edge.right) { - ret_id = tsk_edge_table_add_row(&self->edges, edge.left, edge.right, - edge.parent, edge.child, edge.metadata, edge.metadata_length); - if (ret_id < 0) { - ret = (int) ret_id; - goto out; - } - } - } - ret = tsk_table_collection_build_index(self, 0); - if (ret != 0) { - goto out; - } - -out: - tsk_blkalloc_free(&edge_list_heap); - tsk_safe_free(degree); - tsk_edge_table_free(&edges); - return ret; -} - -int TSK_WARN_UNUSED -tsk_table_collection_extend_edges( - tsk_table_collection_t *self, int max_iter, tsk_flags_t TSK_UNUSED(options)) -{ - int ret = 0; - tsk_size_t last_num_edges; - last_num_edges = self->edges.num_rows; - for (int j = 0; j <= max_iter; j++) { - ret = do_extend(self, TSK_DIR_FORWARD); - if (ret != 0) { - goto out; - } - ret = do_extend(self, TSK_DIR_REVERSE); - if (ret != 0) { - goto out; - } - if (self->edges.num_rows == last_num_edges) { - break; - } else { - last_num_edges = self->edges.num_rows; - } - } - -out: - return ret; -} - int TSK_WARN_UNUSED tsk_table_collection_subset(tsk_table_collection_t *self, const tsk_id_t *nodes, tsk_size_t num_nodes, tsk_flags_t options) diff --git a/c/tskit/tables.h b/c/tskit/tables.h index d815a322b7..048613de0f 100644 --- a/c/tskit/tables.h +++ b/c/tskit/tables.h @@ -4359,38 +4359,6 @@ Options can be specified by providing one or more of the following bitwise int tsk_table_collection_simplify(tsk_table_collection_t *self, const tsk_id_t *samples, tsk_size_t num_samples, tsk_flags_t options, tsk_id_t *node_map); -/** -@brief Extends edges - -Modifies the tables in place so that the span covered by ancestral nodes -is "extended" to regions of the genome according to the following rule: -If an ancestral segment corresponding to node `n` has parent `p` and -child `c` on some portion of the genome, and on an adjacent segment of -genome `p` is the 
immediate parent of `c`, then `n` is inserted into the -edge from `p` to `c`. This involves extending the span of the edges -from `p` to `n` and `n` to `c` and reducing the span of the edge from -`p` to `c`. Since the latter edge may be removed entirely, this process -reduces (or at least does not increase) the number of edges in the tree -sequence. - -The method works by iterating over the genome to look for edges that can -be extended in this way; the maximum number of such iterations is -controlled by ``max_iter``. - - -@rst - -**Options**: None currently defined. -@endrst - -@param self A pointer to a tsk_table_collection_t object. -@param max_iter The maximum number of iterations over the tree sequence. -@param options Bitwise option flags. (UNUSED) -@return Return 0 on success or a negative value on failure. -*/ -int tsk_table_collection_extend_edges( - tsk_table_collection_t *self, int max_iter, tsk_flags_t options); - /** @brief Subsets and reorders a table collection according to an array of nodes. 
diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 4ac90cca74..0cf002d132 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -6887,20 +6887,20 @@ do_extend(tsk_treeseq_t *self, int direction) tsk_edge_table_t edges; tsk_edge_t edge; double there, left, right; - tsk_id_t sign_int; bool forwards = (direction == TSK_DIR_FORWARD); tsk_tree_position_t tree_pos; bool valid; + tsk_table_collection_t *tables = self->tables; // need to do this so tsk_safe_free works on it if it's not initialized memset(&edges, 0, sizeof(edges)); - degree = tsk_malloc(self->tables->nodes.num_rows * sizeof(*degree)); + degree = tsk_malloc(tables->nodes.num_rows * sizeof(*degree)); if (degree == NULL) { ret = TSK_ERR_NO_MEMORY; goto out; } - tsk_memset(degree, 0x00, self->tables->nodes.num_rows * sizeof(*degree)); + tsk_memset(degree, 0x00, tables->nodes.num_rows * sizeof(*degree)); ret = tsk_blkalloc_init(&edge_list_heap, 8192); if (ret != 0) { @@ -6912,21 +6912,16 @@ do_extend(tsk_treeseq_t *self, int direction) goto out; } - ret = tsk_edge_table_copy(&self->tables->edges, &edges, 0); - if (ret != 0) { - goto out; - } - ret = tsk_edge_table_clear(&self->tables->edges); + ret = tsk_edge_table_copy(&tables->edges, &edges, 0); if (ret != 0) { goto out; } + // can't clear the edge table as that removes the indices also if (forwards) { - sign_int = 1; near_side = edges.left; far_side = edges.right; } else { - sign_int = -1; near_side = edges.right; far_side = edges.left; } @@ -6950,7 +6945,7 @@ do_extend(tsk_treeseq_t *self, int direction) remove_unextended(&edges_in_head, &edges_in_tail); remove_unextended(&edges_out_head, &edges_out_tail); - for (tj = tree_pos.out.start; tj != tree_pos.out.stop; tj += sign_int) { + for (tj = tree_pos.out.start; tj != tree_pos.out.stop; tj += direction) { e = tree_pos.out.order[tj]; // add edge to pending_out ret = extend_edges_append_entry( @@ -6960,7 +6955,7 @@ do_extend(tsk_treeseq_t *self, int direction) goto out; } } - for (tj = tree_pos.in.start; tj 
!= tree_pos.in.stop; tj += sign_int) { + for (tj = tree_pos.in.start; tj != tree_pos.in.stop; tj += direction) { e = tree_pos.in.order[tj]; // add edge to pending_in ret = extend_edges_append_entry( @@ -7017,11 +7012,16 @@ do_extend(tsk_treeseq_t *self, int direction) } } + ret = tsk_edge_table_clear(&tables->edges); + if (ret != 0) { + goto out; + } + // done! write out new edge tables for (tj = 0; tj < (tsk_id_t) edges.num_rows; tj++) { tsk_edge_table_get_row(&edges, tj, &edge); if (edge.left < edge.right) { - ret_id = tsk_edge_table_add_row(&self->tables->edges, edge.left, edge.right, + ret_id = tsk_edge_table_add_row(&tables->edges, edge.left, edge.right, edge.parent, edge.child, edge.metadata, edge.metadata_length); if (ret_id < 0) { ret = (int) ret_id; @@ -7029,7 +7029,9 @@ do_extend(tsk_treeseq_t *self, int direction) } } } - ret = tsk_table_collection_build_index(self->tables, 0); + // TODO: do we need to do anything else to make the tree sequence up-to-date with the + // tables? (eg re-initialize?) 
+ ret = tsk_table_collection_build_index(tables, 0); if (ret != 0) { goto out; } @@ -7043,28 +7045,45 @@ do_extend(tsk_treeseq_t *self, int direction) } int TSK_WARN_UNUSED -tsk_tree_sequence_extend_edges( - tsk_treeseq_t *self, int max_iter, tsk_flags_t TSK_UNUSED(options)) +tsk_treeseq_extend_edges(tsk_treeseq_t *self, int max_iter, + tsk_flags_t TSK_UNUSED(options), tsk_treeseq_t *output) { int ret = 0; - tsk_size_t last_num_edges; - last_num_edges = self->tables->edges.num_rows; - for (int j = 0; j <= max_iter; j++) { - ret = do_extend(self, TSK_DIR_FORWARD); + tsk_table_collection_t *tables = tsk_malloc(sizeof(*tables)); + tsk_size_t last_num_edges = self->tables->edges.num_rows; + + memset(output, 0, sizeof(*output)); + ret = tsk_treeseq_copy_tables(self, tables, 0); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_init( + output, tables, TSK_TAKE_OWNERSHIP & TSK_TS_INIT_BUILD_INDEXES); + if (ret != 0) { + goto out; + } + tables = NULL; // should this be before the error check? 
+ + for (int j = 0; j < max_iter; j++) { + ret = do_extend(output, TSK_DIR_FORWARD); if (ret != 0) { goto out; } - ret = do_extend(self, TSK_DIR_REVERSE); + ret = do_extend(output, TSK_DIR_REVERSE); if (ret != 0) { goto out; } - if (self->tables->edges.num_rows == last_num_edges) { + if (output->tables->edges.num_rows == last_num_edges) { break; } else { - last_num_edges = self->tables->edges.num_rows; + last_num_edges = output->tables->edges.num_rows; } } out: + if (tables != NULL) { + tsk_table_collection_free(tables); + tsk_safe_free(tables); + } return ret; } diff --git a/c/tskit/trees.h b/c/tskit/trees.h index e512f8d67a..08b595f684 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -891,6 +891,38 @@ int tsk_treeseq_simplify(const tsk_treeseq_t *self, const tsk_id_t *samples, tsk_size_t num_samples, tsk_flags_t options, tsk_treeseq_t *output, tsk_id_t *node_map); +/** +@brief Extends edges + +Returns a modified tree sequence in which the span covered by ancestral nodes +is "extended" to regions of the genome according to the following rule: +If an ancestral segment corresponding to node `n` has parent `p` and +child `c` on some portion of the genome, and on an adjacent segment of +genome `p` is the immediate parent of `c`, then `n` is inserted into the +edge from `p` to `c`. This involves extending the span of the edges +from `p` to `n` and `n` to `c` and reducing the span of the edge from +`p` to `c`. Since the latter edge may be removed entirely, this process +reduces (or at least does not increase) the number of edges in the tree +sequence. + +The method works by iterating over the genome to look for edges that can +be extended in this way; the maximum number of such iterations is +controlled by ``max_iter``. + + +@rst + +**Options**: None currently defined. +@endrst + +@param self A pointer to a tsk_treeseq_t object. +@param max_iter The maximum number of iterations over the tree sequence. +@param options Bitwise option flags. 
(UNUSED) +@return Return 0 on success or a negative value on failure. +*/ +int tsk_treeseq_extend_edges( + tsk_treeseq_t *self, int max_iter, tsk_flags_t options, tsk_treeseq_t *output); + /** @} */ int tsk_treeseq_split_edges(const tsk_treeseq_t *self, double time, tsk_flags_t flags, diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 55dd1c3253..60dd00ec2e 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -6989,32 +6989,6 @@ TableCollection_link_ancestors(TableCollection *self, PyObject *args, PyObject * return ret; } -static PyObject * -TableCollection_extend_edges(TableCollection *self, PyObject *args, PyObject *kwds) -{ - int err; - PyObject *ret = NULL; - int max_iter; - tsk_flags_t options = 0; - static char *kwlist[] = { "max_iter", NULL }; - - if (TableCollection_check_state(self) != 0) { - goto out; - } - if (!PyArg_ParseTupleAndKeywords(args, kwds, "i", kwlist, &max_iter)) { - goto out; - } - - err = tsk_table_collection_extend_edges(self->tables, max_iter, options); - if (err != 0) { - handle_library_error(err); - goto out; - } - ret = Py_BuildValue(""); -out: - return ret; -} - static PyObject * TableCollection_subset(TableCollection *self, PyObject *args, PyObject *kwds) { @@ -7816,10 +7790,6 @@ static PyMethodDef TableCollection_methods[] = { .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Returns an edge table linking samples to a set of specified ancestors." }, - { .ml_name = "extend_edges", - .ml_meth = (PyCFunction) TableCollection_extend_edges, - .ml_flags = METH_VARARGS | METH_KEYWORDS, - .ml_doc = "Extends edges TODO DOCUMENT." 
}, { .ml_name = "subset", .ml_meth = (PyCFunction) TableCollection_subset, .ml_flags = METH_VARARGS | METH_KEYWORDS, @@ -8968,6 +8938,46 @@ TreeSequence_mean_descendants(TreeSequence *self, PyObject *args, PyObject *kwds return ret; } +static PyObject * +TreeSequence_extend_edges(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + int err; + PyObject *ret = NULL; + int max_iter; + tsk_flags_t options = 0; + static char *kwlist[] = { "max_iter", NULL }; + TreeSequence *output = NULL; + + if (TreeSequence_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords(args, kwds, "i", kwlist, &max_iter)) { + goto out; + } + + output = (TreeSequence *) _PyObject_New((PyTypeObject *) &TreeSequenceType); + if (output == NULL) { + goto out; + } + output->tree_sequence = PyMem_Malloc(sizeof(*output->tree_sequence)); + if (output->tree_sequence == NULL) { + PyErr_NoMemory(); + goto out; + } + + err = tsk_treeseq_extend_edges( + self->tree_sequence, max_iter, options, output->tree_sequence); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = (PyObject *) output; + output = NULL; +out: + Py_XDECREF(output); + return ret; +} + /* Error value returned from summary_func callback if an error occured. * This is chosen so that it is not a valid tskit error code and so can * never be mistaken for a different error */ @@ -10555,6 +10565,10 @@ static PyMethodDef TreeSequence_methods[] = { .ml_meth = (PyCFunction) TreeSequence_split_edges, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Returns a copy of this tree sequence edges split at time t" }, + { .ml_name = "extend_edges", + .ml_meth = (PyCFunction) TreeSequence_extend_edges, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Extends edges, creating unary nodes." 
}, { .ml_name = "has_reference_sequence", .ml_meth = (PyCFunction) TreeSequence_has_reference_sequence, .ml_flags = METH_NOARGS, diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index da349eb3e6..0634f2556b 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -8749,16 +8749,10 @@ def verify_extend_edges(self, ts, max_iter=10): # finally, compare C version to python version py_et = self.py_extend_edges(ts, max_iter=max_iter).dump_tables() et = ets.dump_tables() - # py_et.tree_sequence().dump("py.trees") - # ets.dump("c.trees") + py_et.tree_sequence().dump("py.trees") + ets.dump("c.trees") et.assert_equals(py_et) - def test_extend_edges_errors(self): - t = msprime.simulate(5, random_seed=126).dump_tables() - t.drop_index() - with pytest.raises(_tskit.LibraryError, match="TSK_ERR_TABLES_NOT_INDEXED"): - t.extend_edges() - def test_runs(self): ts = msprime.simulate(5, random_seed=126) self.verify_extend_edges(ts) @@ -8826,8 +8820,7 @@ def test_simple_ex(self): edges.add_row(parent=p, child=c, left=l, right=r) tables.sort() ts = tables.tree_sequence() - tables.extend_edges() - ets = tables.tree_sequence() + ets = ts.extend_edges() assert ts.num_edges == 18 assert ets.num_edges == 13 for t in ets.trees(): @@ -8844,14 +8837,14 @@ def test_wright_fisher_trees(self): tables.sort() tables.simplify() ts = tables.tree_sequence() - # self.verify_extend_edges(ts, max_iter=1) + self.verify_extend_edges(ts, max_iter=1) self.verify_extend_edges(ts) def test_wright_fisher_trees_unsimplified(self): tables = wf.wf_sim(N=6, ngens=22, deep_history=False, seed=4) tables.sort() ts = tables.tree_sequence() - # self.verify_extend_edges(ts, max_iter=1) + self.verify_extend_edges(ts, max_iter=1) self.verify_extend_edges(ts) def test_wright_fisher_trees_with_history(self): @@ -8859,7 +8852,7 @@ def test_wright_fisher_trees_with_history(self): tables.sort() tables.simplify() ts = tables.tree_sequence() - # self.verify_extend_edges(ts, 
max_iter=1) + self.verify_extend_edges(ts, max_iter=1) self.verify_extend_edges(ts) # def test_bigger_wright_fisher(self): diff --git a/python/tskit/tables.py b/python/tskit/tables.py index a0cf6d16a3..cb2ff7d01f 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -4037,19 +4037,6 @@ def drop_index(self): """ self._ll_tables.drop_index() - def extend_edges(self, max_iter=10): - """ - Modifies the tables in place by applying the operation described - in :meth:`TreeSequence.extend_edges`. - - :param int max_iter: The maximum number of iterations over the tree - sequence. Defaults to 10. - """ - max_iter = int(max_iter) - if max_iter <= 0: - raise ValueError("max_iter must be a positive integer.") - self._ll_tables.extend_edges(max_iter) - def subset( self, nodes, diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 836cbafee5..03b0f069c6 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -6938,9 +6938,11 @@ def extend_edges(self, max_iter=10): :return: A new tree sequence with unary nodes extended. :rtype: tskit.TreeSequence """ - t = self.dump_tables() - t.extend_edges(max_iter=max_iter) - return t.tree_sequence() + max_iter = int(max_iter) + if max_iter <= 0: + raise ValueError("max_iter must be a positive integer.") + ll_ts = self._ll_tree_sequence.extend_edges(max_iter) + return TreeSequence(ll_ts) def subset( self, From c7e66215b4abe2a879a3a00501d390dd194049c5 Mon Sep 17 00:00:00 2001 From: peter Date: Thu, 17 Aug 2023 12:23:40 -0700 Subject: [PATCH 81/84] docfix --- c/tskit/trees.h | 1 + 1 file changed, 1 insertion(+) diff --git a/c/tskit/trees.h b/c/tskit/trees.h index 08b595f684..379cede090 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -918,6 +918,7 @@ controlled by ``max_iter``. @param self A pointer to a tsk_treeseq_t object. @param max_iter The maximum number of iterations over the tree sequence. @param options Bitwise option flags. 
(UNUSED) +@param output A pointer to an uninitialised tsk_treeseq_t object. @return Return 0 on success or a negative value on failure. */ int tsk_treeseq_extend_edges( From 8a1335ff8d596aed306690d9e630931cf4fbd050 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 18 Aug 2023 13:51:19 +0100 Subject: [PATCH 82/84] Refactor implementation and tests --- c/tests/test_trees.c | 61 +++-- c/tskit/trees.c | 155 ++++++------ c/tskit/trees.h | 2 +- python/tests/test_extend_edges.py | 379 ++++++++++++++++++++++++++++++ python/tests/test_topology.py | 347 --------------------------- 5 files changed, 495 insertions(+), 449 deletions(-) create mode 100644 python/tests/test_extend_edges.py diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c index a41efa6884..5acf465db6 100644 --- a/c/tests/test_trees.c +++ b/c/tests/test_trees.c @@ -8217,41 +8217,49 @@ test_extend_edges_simple(void) { int ret; tsk_treeseq_t ts, ets; - tsk_table_collection_t tables; - const char *nodes_ex = "1 0 -1 -1\n" "1 0 -1 -1\n" "0 2.0 -1 -1\n"; const char *edges_ex = "0 10 2 0\n" "0 10 2 1\n"; - ret = tsk_table_collection_init(&tables, 0); - CU_ASSERT_EQUAL_FATAL(ret, 0); - tables.sequence_length = 10; - - parse_nodes(nodes_ex, &tables.nodes); - CU_ASSERT_EQUAL_FATAL(tables.nodes.num_rows, 3); - parse_edges(edges_ex, &tables.edges); - CU_ASSERT_EQUAL_FATAL(tables.edges.num_rows, 2); - - ret = tsk_treeseq_init(&ts, &tables, TSK_TS_INIT_BUILD_INDEXES); - CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_treeseq_from_text(&ts, 10, nodes_ex, edges_ex, NULL, NULL, NULL, NULL, NULL, 0); ret = tsk_treeseq_extend_edges(&ts, 10, 0, &ets); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_TRUE(tsk_table_collection_equals(ts.tables, ets.tables, 0)); - tsk_table_collection_free(&tables); tsk_treeseq_free(&ts); tsk_treeseq_free(&ets); } static void -test_extend_edges(void) +assert_equal_except_edges(const tsk_treeseq_t *ts1, const tsk_treeseq_t *ts2) { + tsk_table_collection_t t1, t2; int ret; - tsk_treeseq_t ts, ets; + + ret = 
tsk_table_collection_copy(ts1->tables, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_table_collection_copy(ts2->tables, &t2, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tsk_edge_table_clear(&t1.edges); + tsk_edge_table_clear(&t2.edges); + + CU_ASSERT_TRUE(tsk_table_collection_equals(&t1, &t2, 0)); + + tsk_table_collection_free(&t1); + tsk_table_collection_free(&t2); +} + +static void +test_extend_edges(void) +{ + int ret, max_iter; tsk_table_collection_t tables; + tsk_treeseq_t ts, ets; /* 7 and 8 should be extended to the whole sequence 6 6 6 6 @@ -8282,9 +8290,9 @@ test_extend_edges(void) "5 10 5 3\n" "0 2 6 4\n" "5 10 6 4\n" + "2 5 6 3\n" "0 2 6 5\n" "7 10 6 5\n" - "2 5 6 3\n" "2 5 6 7\n" "5 7 6 8\n" "2 5 7 2\n" @@ -8292,28 +8300,41 @@ test_extend_edges(void) "5 7 8 1\n" "5 7 8 5\n"; + /* Doing this rather than tsk_treeseq_from_text because the edges are unsorted */ ret = tsk_table_collection_init(&tables, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); tables.sequence_length = 10; - parse_nodes(nodes_ex, &tables.nodes); CU_ASSERT_EQUAL_FATAL(tables.nodes.num_rows, 9); parse_edges(edges_ex, &tables.edges); CU_ASSERT_EQUAL_FATAL(tables.edges.num_rows, 18); ret = tsk_table_collection_sort(&tables, NULL, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); - ret = tsk_treeseq_init(&ts, &tables, TSK_TS_INIT_BUILD_INDEXES); CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_extend_edges(&ts, 0, 0, &ets); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE_FATAL(tsk_table_collection_equals(ts.tables, ets.tables, 0)); + /* tsk_treeseq_print_state(&ets, stdout); */ + tsk_treeseq_free(&ets); + + for (max_iter = 1; max_iter < 10; max_iter++) { + ret = tsk_treeseq_extend_edges(&ts, max_iter, 0, &ets); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_equal_except_edges(&ts, &ets); + CU_ASSERT_TRUE(ets.tables->edges.num_rows >= 13); + tsk_treeseq_free(&ets); + } + ret = tsk_treeseq_extend_edges(&ts, 10, 0, &ets); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_EQUAL_FATAL(ets.tables->nodes.num_rows, 9); 
CU_ASSERT_EQUAL_FATAL(ets.tables->edges.num_rows, 13); + tsk_treeseq_free(&ets); tsk_table_collection_free(&tables); tsk_treeseq_free(&ts); - tsk_treeseq_free(&ets); } static void diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 0cf002d132..04936afb71 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -6869,61 +6869,57 @@ remove_unextended(edge_list_t **head, edge_list_t **tail) } static int -do_extend(tsk_treeseq_t *self, int direction) +tsk_treeseq_extend_edges_iter( + const tsk_treeseq_t *self, int direction, tsk_edge_table_t *edges) { // Note: this modifies the edge table, but it does this by (a) removing // some edges, and (b) extending left/right endpoints of others, // while keeping order the same, and so this maintains sortedness // (so, there is no need to sort afterwards). int ret = 0; - tsk_id_t *degree; - tsk_id_t tj, ret_id; + tsk_id_t tj; tsk_id_t e, e1, e2, e_in; - double *near_side, *far_side; tsk_blkalloc_t edge_list_heap; + double *near_side, *far_side; edge_list_t *edges_in_head, *edges_in_tail; edge_list_t *edges_out_head, *edges_out_tail; edge_list_t *ex1, *ex2, *ex_in; - tsk_edge_table_t edges; - tsk_edge_t edge; double there, left, right; bool forwards = (direction == TSK_DIR_FORWARD); tsk_tree_position_t tree_pos; bool valid; - tsk_table_collection_t *tables = self->tables; + const tsk_table_collection_t *tables = self->tables; + const tsk_size_t num_nodes = tables->nodes.num_rows; + const tsk_size_t num_edges = tables->edges.num_rows; + tsk_id_t *degree = tsk_calloc(num_nodes, sizeof(*degree)); + tsk_bool_t *keep = tsk_calloc(num_edges, sizeof(*keep)); - // need to do this so tsk_safe_free works on it if it's not initialized - memset(&edges, 0, sizeof(edges)); + memset(&edge_list_heap, 0, sizeof(edge_list_heap)); + memset(&tree_pos, 0, sizeof(tree_pos)); - degree = tsk_malloc(tables->nodes.num_rows * sizeof(*degree)); - if (degree == NULL) { + if (keep == NULL || degree == NULL) { ret = TSK_ERR_NO_MEMORY; goto out; } - 
tsk_memset(degree, 0x00, tables->nodes.num_rows * sizeof(*degree)); - ret = tsk_blkalloc_init(&edge_list_heap, 8192); if (ret != 0) { goto out; } - ret = tsk_tree_position_init(&tree_pos, self, 0); if (ret != 0) { goto out; } - - ret = tsk_edge_table_copy(&tables->edges, &edges, 0); + ret = tsk_edge_table_copy(&tables->edges, edges, TSK_NO_INIT); if (ret != 0) { goto out; } - // can't clear the edge table as that removes the indices also if (forwards) { - near_side = edges.left; - far_side = edges.right; + near_side = edges->left; + far_side = edges->right; } else { - near_side = edges.right; - far_side = edges.left; + near_side = edges->right; + far_side = edges->left; } edges_in_head = NULL; edges_in_tail = NULL; @@ -6966,12 +6962,12 @@ do_extend(tsk_treeseq_t *self, int direction) } } for (ex1 = edges_out_head; ex1 != NULL; ex1 = ex1->next) { - degree[edges.parent[ex1->edge]] -= 1; - degree[edges.child[ex1->edge]] -= 1; + degree[edges->parent[ex1->edge]] -= 1; + degree[edges->child[ex1->edge]] -= 1; } for (ex1 = edges_in_head; ex1 != NULL; ex1 = ex1->next) { - degree[edges.parent[ex1->edge]] += 1; - degree[edges.child[ex1->edge]] += 1; + degree[edges->parent[ex1->edge]] += 1; + degree[edges->child[ex1->edge]] += 1; } // iterate over pairs of out and in: (ex1, ex2, in) @@ -6981,22 +6977,22 @@ do_extend(tsk_treeseq_t *self, int direction) for (ex2 = edges_out_head; ex2 != NULL; ex2 = ex2->next) { if (!ex2->extended) { e2 = ex2->edge; - if ((edges.parent[e1] == edges.child[e2]) - && (degree[edges.child[e2]] == 0)) { + if ((edges->parent[e1] == edges->child[e2]) + && (degree[edges->child[e2]] == 0)) { for (ex_in = edges_in_head; ex_in != NULL; ex_in = ex_in->next) { e_in = ex_in->edge; - if ((edges.left[e_in] < right) - && (edges.right[e_in] > left)) { - if ((edges.child[e1] == edges.child[e_in]) - && (edges.parent[e2] == edges.parent[e_in])) { + if ((edges->left[e_in] < right) + && (edges->right[e_in] > left)) { + if ((edges->child[e1] == edges->child[e_in]) + && 
(edges->parent[e2] == edges->parent[e_in])) { ex1->extended = true; ex2->extended = true; ex_in->extended = true; far_side[e1] = there; far_side[e2] = there; near_side[e_in] = there; - degree[edges.parent[e1]] += 2; + degree[edges->parent[e1]] += 2; } } } @@ -7012,78 +7008,75 @@ do_extend(tsk_treeseq_t *self, int direction) } } - ret = tsk_edge_table_clear(&tables->edges); - if (ret != 0) { - goto out; - } - - // done! write out new edge tables - for (tj = 0; tj < (tsk_id_t) edges.num_rows; tj++) { - tsk_edge_table_get_row(&edges, tj, &edge); - if (edge.left < edge.right) { - ret_id = tsk_edge_table_add_row(&tables->edges, edge.left, edge.right, - edge.parent, edge.child, edge.metadata, edge.metadata_length); - if (ret_id < 0) { - ret = (int) ret_id; - goto out; - } - } + for (e = 0; e < (tsk_id_t) num_edges; e++) { + keep[e] = edges->left[e] < edges->right[e]; } - // TODO: do we need to do anything else to make the tree sequence up-to-date with the - // tables? (eg re-initialize?) - ret = tsk_table_collection_build_index(tables, 0); - if (ret != 0) { - goto out; - } - + ret = tsk_edge_table_keep_rows(edges, keep, 0, NULL); out: tsk_blkalloc_free(&edge_list_heap); - tsk_safe_free(degree); - tsk_edge_table_free(&edges); tsk_tree_position_free(&tree_pos); + tsk_safe_free(degree); + tsk_safe_free(keep); return ret; } int TSK_WARN_UNUSED -tsk_treeseq_extend_edges(tsk_treeseq_t *self, int max_iter, +tsk_treeseq_extend_edges(const tsk_treeseq_t *self, int max_iter, tsk_flags_t TSK_UNUSED(options), tsk_treeseq_t *output) { int ret = 0; - tsk_table_collection_t *tables = tsk_malloc(sizeof(*tables)); - tsk_size_t last_num_edges = self->tables->edges.num_rows; - - memset(output, 0, sizeof(*output)); - ret = tsk_treeseq_copy_tables(self, tables, 0); + tsk_table_collection_t tables; + tsk_treeseq_t ts; + int iter, j; + tsk_size_t last_num_edges; + const int direction[] = { TSK_DIR_FORWARD, TSK_DIR_REVERSE }; + + tsk_memset(&tables, 0, sizeof(tables)); + tsk_memset(&ts, 0, 
sizeof(ts)); + tsk_memset(output, 0, sizeof(*output)); + + /* Note: there is a fair bit of copying of table data in this implementation + * currently, as we create a new tree sequence for each iteration, which + * takes a full copy of the input tables. We could streamline this by + * adding a flag to treeseq_init which says "steal a reference to these + * tables and *don't* free them at the end". Then, we would only need + * one copy of the full tables, and could pass in a standalone edge + * table to use for in-place updating. + */ + ret = tsk_table_collection_copy(self->tables, &tables, 0); if (ret != 0) { goto out; } - ret = tsk_treeseq_init( - output, tables, TSK_TAKE_OWNERSHIP & TSK_TS_INIT_BUILD_INDEXES); + ret = tsk_treeseq_init(&ts, &tables, 0); if (ret != 0) { goto out; } - tables = NULL; // should this be before the error check? - for (int j = 0; j < max_iter; j++) { - ret = do_extend(output, TSK_DIR_FORWARD); - if (ret != 0) { - goto out; - } - ret = do_extend(output, TSK_DIR_REVERSE); - if (ret != 0) { - goto out; + last_num_edges = tsk_treeseq_get_num_edges(&ts); + for (iter = 0; iter < max_iter; iter++) { + for (j = 0; j < 2; j++) { + ret = tsk_treeseq_extend_edges_iter(&ts, direction[j], &tables.edges); + if (ret != 0) { + goto out; + } + /* We're done with the current ts now */ + tsk_treeseq_free(&ts); + ret = tsk_treeseq_init(&ts, &tables, TSK_TS_INIT_BUILD_INDEXES); + if (ret != 0) { + goto out; + } } - if (output->tables->edges.num_rows == last_num_edges) { + if (last_num_edges == tsk_treeseq_get_num_edges(&ts)) { break; - } else { - last_num_edges = output->tables->edges.num_rows; } + last_num_edges = tsk_treeseq_get_num_edges(&ts); } + /* Hand ownership of the tree sequence to the calling code */ + tsk_memcpy(output, &ts, sizeof(ts)); + tsk_memset(&ts, 0, sizeof(*output)); out: - if (tables != NULL) { - tsk_table_collection_free(tables); - tsk_safe_free(tables); - } + tsk_treeseq_free(&ts); + tsk_table_collection_free(&tables); return ret; } 
diff --git a/c/tskit/trees.h b/c/tskit/trees.h index 379cede090..738d399699 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -922,7 +922,7 @@ controlled by ``max_iter``. @return Return 0 on success or a negative value on failure. */ int tsk_treeseq_extend_edges( - tsk_treeseq_t *self, int max_iter, tsk_flags_t options, tsk_treeseq_t *output); + const tsk_treeseq_t *self, int max_iter, tsk_flags_t options, tsk_treeseq_t *output); /** @} */ diff --git a/python/tests/test_extend_edges.py b/python/tests/test_extend_edges.py new file mode 100644 index 0000000000..ffca67b2e7 --- /dev/null +++ b/python/tests/test_extend_edges.py @@ -0,0 +1,379 @@ +import msprime +import numpy as np +import pytest + +import tests.test_wright_fisher as wf +import tskit +from tests import tsutil +from tests.test_highlevel import get_example_tree_sequences + +# ↑ See https://github.com/tskit-dev/tskit/issues/1804 for when +# we can remove this. + + +def extend_edges(ts, max_iter=10): + tables = ts.dump_tables() + + last_num_edges = ts.num_edges + for _ in range(max_iter): + for forwards in [True, False]: + edges = _extend(ts, forwards=forwards) + tables.edges.replace_with(edges) + tables.build_index() + ts = tables.tree_sequence() + if ts.num_edges == last_num_edges: + break + else: + last_num_edges = ts.num_edges + return ts + + +def _extend(ts, forwards=True): + degree = np.full(ts.num_nodes, 0) + keep = np.full(ts.num_edges, True, dtype=bool) + + edges = ts.tables.edges.copy() + + # "here" will be left if fowards else right; + # and "there" is the other + new_left = edges.left.copy() + new_right = edges.right.copy() + if forwards: + direction = 1 + # in C we can just modify these in place, but in + # python they are (silently) immutable + new_here = new_left + new_there = new_right + else: + direction = -1 + new_here = new_right + new_there = new_left + edges_out = [] + edges_in = [] + + tree_pos = tsutil.TreePosition(ts) + if forwards: + valid = tree_pos.next() + else: + valid = 
tree_pos.prev() + while valid: + left, right = tree_pos.interval + there = right if forwards else left + + # clear out non-extended or postponed edges + edges_out = [[e, False] for e, x in edges_out if x] + edges_in = [[e, False] for e, x in edges_in if x] + + for j in range(tree_pos.out_range.start, tree_pos.out_range.stop, direction): + e = tree_pos.out_range.order[j] + edges_out.append([e, False]) + + for j in range(tree_pos.in_range.start, tree_pos.in_range.stop, direction): + e = tree_pos.in_range.order[j] + edges_in.append([e, False]) + + for e, _ in edges_out: + degree[edges.parent[e]] -= 1 + degree[edges.child[e]] -= 1 + for e, _ in edges_in: + degree[edges.parent[e]] += 1 + degree[edges.child[e]] += 1 + + assert np.all(degree >= 0) + for ex1 in edges_out: + if not ex1[1]: + e1 = ex1[0] + for ex2 in edges_out: + if not ex2[1]: + # the intermediate node should not be present in + # the new tree + e2 = ex2[0] + if (edges.parent[e1] == edges.child[e2]) and ( + degree[edges.child[e2]] == 0 + ): + for ex_in in edges_in: + e_in = ex_in[0] + # we might have passed the interval that a + # postponed edge in covers, in which case + # we should skip it + if new_left[e_in] < right and new_right[e_in] > left: + if ( + edges.child[e1] == edges.child[e_in] + and edges.parent[e2] == edges.parent[e_in] + ): + ex1[1] = True + ex2[1] = True + ex_in[1] = True + new_there[e1] = there + new_there[e2] = there + new_here[e_in] = there + # amend degree: the intermediate + # node has 2 edges instead of 0 + degree[edges.parent[e1]] += 2 + # end of loop, next tree + if forwards: + valid = tree_pos.next() + else: + valid = tree_pos.prev() + + for j in range(edges.num_rows): + left = new_left[j] + right = new_right[j] + if left < right: + edges[j] = edges[j].replace(left=left, right=right) + else: + keep[j] = False + edges.keep_rows(keep) + return edges + + +class TestExtendEdges: + """ + Test the 'extend edges' method + """ + + def verify_extend_edges(self, ts, max_iter=10): + # This 
can still fail for various weird examples: + # for instance, if adjacent trees have + # a <- b <- c <- d and a <- d (where say b was + # inserted in an earlier pass), then b and c + # won't be extended + + ets = ts.extend_edges(max_iter=max_iter) + assert ts.num_samples == ets.num_samples + assert ts.num_nodes == ets.num_nodes + assert ts.num_edges >= ets.num_edges + t = ts.simplify().tables + et = ets.simplify().tables + t.assert_equals(et, ignore_provenance=True) + old_edges = {} + for e in ts.edges(): + k = (e.parent, e.child) + if k not in old_edges: + old_edges[k] = [] + old_edges[k].append((e.left, e.right)) + + for e in ets.edges(): + # e should be in old_edges, + # but with modified limits: + # USUALLY overlapping limits, but + # not necessarily after more than one pass + k = (e.parent, e.child) + assert k in old_edges + if max_iter == 1: + overlaps = False + for left, right in old_edges[k]: + if (left <= e.right) and (right >= e.left): + overlaps = True + assert overlaps + + if max_iter > 1: + chains = [] + for _, tt, ett in ts.coiterate(ets): + this_chains = [] + for a in tt.nodes(): + assert a in ett.nodes() + b = tt.parent(a) + if b != tskit.NULL: + c = tt.parent(b) + if c != tskit.NULL: + this_chains.append((a, b, c)) + assert b in ett.nodes() + # the relationship a <- b should still be in the tree + p = a + while p != tskit.NULL and p != b: + p = ett.parent(p) + assert p == b + chains.append(this_chains) + + extended_ac = {} + not_extended_ac = {} + extended_ab = {} + not_extended_ab = {} + for k, (interval, tt, ett) in enumerate(ts.coiterate(ets)): + for j in (k - 1, k + 1): + if j < 0 or j >= len(chains): + continue + else: + this_chains = chains[j] + for a, b, c in this_chains: + if ( + a in tt.nodes() + and tt.parent(a) == c + and b not in tt.nodes() + ): + # the relationship a <- b <- c should still be in the tree, + # although maybe they aren't direct parent-offspring + # UNLESS we've got an ambiguous case, where on the opposite + # side of the 
interval a chain a <- b' <- c got extended + # into the region OR b got inserted into another chain + assert a in ett.nodes() + assert c in ett.nodes() + if b not in ett.nodes(): + if (a, c) not in not_extended_ac: + not_extended_ac[(a, c)] = [] + not_extended_ac[(a, c)].append(interval) + else: + if (a, c) not in extended_ac: + extended_ac[(a, c)] = [] + extended_ac[(a, c)].append(interval) + p = a + while p != tskit.NULL and p != b: + p = ett.parent(p) + if p != b: + if (a, b) not in not_extended_ab: + not_extended_ab[(a, b)] = [] + not_extended_ab[(a, b)].append(interval) + else: + if (a, b) not in extended_ab: + extended_ab[(a, b)] = [] + extended_ab[(a, b)].append(interval) + while p != tskit.NULL and p != c: + p = ett.parent(p) + assert p == c + for a, c in not_extended_ac: + # check that a <- ... <- c has been extended somewhere + # although not necessarily from an adjacent segment + assert (a, c) in extended_ac + for interval in not_extended_ac[(a, c)]: + ett = ets.at(interval.left) + assert ett.parent(a) != c + for k in not_extended_ab: + assert k in extended_ab + for interval in not_extended_ab[k]: + assert interval in extended_ab[k] + + # finally, compare C version to python version + py_et = extend_edges(ts, max_iter=max_iter).dump_tables() + et = ets.dump_tables() + et.assert_equals(py_et) + + def test_runs(self): + ts = msprime.simulate(5, random_seed=126) + self.verify_extend_edges(ts) + + def test_max_iter(self): + ts = msprime.simulate(5, random_seed=126) + with pytest.raises(ValueError, match="max_iter"): + ets = ts.extend_edges(max_iter=0) + ets = ts.extend_edges(max_iter=1) + et = ets.extend_edges(max_iter=1).dump_tables() + eet = ets.extend_edges(max_iter=2).dump_tables() + eet.assert_equals(et) + + def test_simple_ex(self): + # An example where you need to go forwards *and* backwards: + # 7 and 8 should be extended to the whole sequence + # + # 6 6 6 6 + # +-+-+ +-+-+ +-+-+ +-+-+ + # | | 7 | | 8 | | + # | | ++-+ | | +-++ | | + # 4 5 4 | | 4 | 
5 4 5 + # +++ +++ +++ | | | | +++ +++ +++ + # 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 + # + node_times = { + 0: 0, + 1: 0, + 2: 0, + 3: 0, + 4: 1.0, + 5: 1.0, + 6: 3.0, + 7: 2.0, + 8: 2.0, + } + # (p, c, l, r) + edge_stuff = [ + (4, 0, 0, 10), + (4, 1, 0, 5), + (4, 1, 7, 10), + (5, 2, 0, 2), + (5, 2, 5, 10), + (5, 3, 0, 2), + (5, 3, 5, 10), + (6, 4, 0, 2), + (6, 4, 5, 10), + (6, 5, 0, 2), + (6, 5, 7, 10), + (6, 3, 2, 5), + (6, 7, 2, 5), + (6, 8, 5, 7), + (7, 2, 2, 5), + (7, 4, 2, 5), + (8, 1, 5, 7), + (8, 5, 5, 7), + ] + tables = tskit.TableCollection(sequence_length=10) + nodes = tables.nodes + for n, t in node_times.items(): + flags = tskit.NODE_IS_SAMPLE if n < 4 else 0 + nodes.add_row(time=t, flags=flags) + edges = tables.edges + for p, c, l, r in edge_stuff: + edges.add_row(parent=p, child=c, left=l, right=r) + tables.sort() + ts = tables.tree_sequence() + ets = ts.extend_edges() + assert ts.num_edges == 18 + assert ets.num_edges == 13 + for t in ets.trees(): + assert 7 in t.nodes() + assert 8 in t.nodes() + assert t.parent(4) == 7 + assert t.parent(7) == 6 + assert t.parent(5) == 8 + assert t.parent(8) == 6 + self.verify_extend_edges(ts) + + def test_wright_fisher_trees(self): + tables = wf.wf_sim(N=5, ngens=20, deep_history=False, seed=3) + tables.sort() + tables.simplify() + ts = tables.tree_sequence() + self.verify_extend_edges(ts, max_iter=1) + self.verify_extend_edges(ts) + + def test_wright_fisher_trees_unsimplified(self): + tables = wf.wf_sim(N=6, ngens=22, deep_history=False, seed=4) + tables.sort() + ts = tables.tree_sequence() + self.verify_extend_edges(ts, max_iter=1) + self.verify_extend_edges(ts) + + def test_wright_fisher_trees_with_history(self): + tables = wf.wf_sim(N=8, ngens=15, deep_history=True, seed=5) + tables.sort() + tables.simplify() + ts = tables.tree_sequence() + self.verify_extend_edges(ts, max_iter=1) + self.verify_extend_edges(ts) + + # def test_bigger_wright_fisher(self): + # tables = wf.wf_sim(N=50, ngens=15, deep_history=True, seed=6) 
+ # tables.sort() + # tables.simplify() + # ts = tables.tree_sequence() + # self.verify_extend_edges(ts, max_iter=1) + # self.verify_extend_edges(ts, max_iter=200) + + +class TestExamples: + """ + Compare the ts method with local implementation. + """ + + def check(self, ts): + lib_ts = ts.extend_edges() + py_ts = extend_edges(ts) + lib_ts.tables.assert_equals(py_ts.tables) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_suite_examples_defaults(self, ts): + self.check(ts) + + @pytest.mark.parametrize("n", [3, 4, 5]) + def test_all_trees_ts(self, n): + ts = tsutil.all_trees_ts(n) + self.check(ts) diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index 0634f2556b..d564ec0590 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -8515,350 +8515,3 @@ def test_is_isolated_bad(self): tree.is_isolated("abc") with pytest.raises(TypeError): tree.is_isolated(1.1) - - -class TestExtendEdges: - """ - Test the 'extend edges' method - """ - - def py_extend_edges(self, ts, max_iter=10): - last_num_edges = ts.num_edges - for _ in range(max_iter): - ts = self._extend(ts, forwards=True) - ts = self._extend(ts, forwards=False) - if ts.num_edges == last_num_edges: - break - else: - last_num_edges = ts.num_edges - return ts - - def _extend(self, ts, forwards=True): - degree = np.full(ts.num_nodes, 0) - - t = ts.dump_tables() - edges = t.edges.copy() - t.edges.clear() - - # "here" will be left if fowards else right; - # and "there" is the other - new_left = edges.left.copy() - new_right = edges.right.copy() - if forwards: - direction = 1 - # in C we can just modify these in place, but in - # python they are (silently) immutable - new_here = new_left - new_there = new_right - else: - direction = -1 - new_here = new_right - new_there = new_left - edges_out = [] - edges_in = [] - - tree_pos = tsutil.TreePosition(ts) - if forwards: - valid = tree_pos.next() - else: - valid = tree_pos.prev() - while valid: - 
left, right = tree_pos.interval - there = right if forwards else left - - # clear out non-extended or postponed edges - edges_out = [[e, False] for e, x in edges_out if x] - edges_in = [[e, False] for e, x in edges_in if x] - - for j in range( - tree_pos.out_range.start, tree_pos.out_range.stop, direction - ): - e = tree_pos.out_range.order[j] - edges_out.append([e, False]) - - for j in range(tree_pos.in_range.start, tree_pos.in_range.stop, direction): - e = tree_pos.in_range.order[j] - edges_in.append([e, False]) - - for e, _ in edges_out: - degree[edges.parent[e]] -= 1 - degree[edges.child[e]] -= 1 - for e, _ in edges_in: - degree[edges.parent[e]] += 1 - degree[edges.child[e]] += 1 - - assert np.all(degree >= 0) - for ex1 in edges_out: - if not ex1[1]: - e1 = ex1[0] - for ex2 in edges_out: - if not ex2[1]: - # the intermediate node should not be present in - # the new tree - e2 = ex2[0] - if (edges.parent[e1] == edges.child[e2]) and ( - degree[edges.child[e2]] == 0 - ): - for ex_in in edges_in: - e_in = ex_in[0] - # we might have passed the interval that a - # postponed edge in covers, in which case - # we should skip it - if ( - new_left[e_in] < right - and new_right[e_in] > left - ): - if ( - edges.child[e1] == edges.child[e_in] - and edges.parent[e2] == edges.parent[e_in] - ): - ex1[1] = True - ex2[1] = True - ex_in[1] = True - new_there[e1] = there - new_there[e2] = there - new_here[e_in] = there - # amend degree: the intermediate - # node has 2 edges instead of 0 - degree[edges.parent[e1]] += 2 - # end of loop, next tree - if forwards: - valid = tree_pos.next() - else: - valid = tree_pos.prev() - - for j in range(edges.num_rows): - left = new_left[j] - right = new_right[j] - if left < right: - e = edges[j].replace(left=left, right=right) - t.edges.append(e) - t.build_index() - return t.tree_sequence() - - def verify_extend_edges(self, ts, max_iter=10): - # This can still fail for various weird examples: - # for instance, if adjacent trees have - # a <- b <- 
c <- d and a <- d (where say b was - # inserted in an earlier pass), then b and c - # won't be extended - - ets = ts.extend_edges(max_iter=max_iter) - assert ts.num_samples == ets.num_samples - assert ts.num_nodes == ets.num_nodes - assert ts.num_edges >= ets.num_edges - t = ts.simplify().tables - et = ets.simplify().tables - t.assert_equals(et, ignore_provenance=True) - old_edges = {} - for e in ts.edges(): - k = (e.parent, e.child) - if k not in old_edges: - old_edges[k] = [] - old_edges[k].append((e.left, e.right)) - - for e in ets.edges(): - # e should be in old_edges, - # but with modified limits: - # USUALLY overlapping limits, but - # not necessarily after more than one pass - k = (e.parent, e.child) - assert k in old_edges - if max_iter == 1: - overlaps = False - for (l, r) in old_edges[k]: - if (l <= e.right) and (r >= e.left): - overlaps = True - assert overlaps - - if max_iter > 1: - chains = [] - for _, tt, ett in ts.coiterate(ets): - this_chains = [] - for a in tt.nodes(): - assert a in ett.nodes() - b = tt.parent(a) - if b != tskit.NULL: - c = tt.parent(b) - if c != tskit.NULL: - this_chains.append((a, b, c)) - assert b in ett.nodes() - # the relationship a <- b should still be in the tree - p = a - while p != tskit.NULL and p != b: - p = ett.parent(p) - assert p == b - chains.append(this_chains) - - extended_ac = {} - not_extended_ac = {} - extended_ab = {} - not_extended_ab = {} - for k, (interval, tt, ett) in enumerate(ts.coiterate(ets)): - for j in (k - 1, k + 1): - if j < 0 or j >= len(chains): - continue - else: - this_chains = chains[j] - for a, b, c in this_chains: - if ( - a in tt.nodes() - and tt.parent(a) == c - and b not in tt.nodes() - ): - # the relationship a <- b <- c should still be in the tree, - # although maybe they aren't direct parent-offspring - # UNLESS we've got an ambiguous case, where on the opposite - # side of the interval a chain a <- b' <- c got extended - # into the region OR b got inserted into another chain - assert a 
in ett.nodes() - assert c in ett.nodes() - if b not in ett.nodes(): - if (a, c) not in not_extended_ac: - not_extended_ac[(a, c)] = [] - not_extended_ac[(a, c)].append(interval) - else: - if (a, c) not in extended_ac: - extended_ac[(a, c)] = [] - extended_ac[(a, c)].append(interval) - p = a - while p != tskit.NULL and p != b: - p = ett.parent(p) - if p != b: - if (a, b) not in not_extended_ab: - not_extended_ab[(a, b)] = [] - not_extended_ab[(a, b)].append(interval) - else: - if (a, b) not in extended_ab: - extended_ab[(a, b)] = [] - extended_ab[(a, b)].append(interval) - while p != tskit.NULL and p != c: - p = ett.parent(p) - assert p == c - for a, c in not_extended_ac: - # check that a <- ... <- c has been extended somewhere - # although not necessarily from an adjacent segment - assert (a, c) in extended_ac - for interval in not_extended_ac[(a, c)]: - ett = ets.at(interval.left) - assert ett.parent(a) != c - for k in not_extended_ab: - assert k in extended_ab - for interval in not_extended_ab[k]: - assert interval in extended_ab[k] - # finally, compare C version to python version - py_et = self.py_extend_edges(ts, max_iter=max_iter).dump_tables() - et = ets.dump_tables() - py_et.tree_sequence().dump("py.trees") - ets.dump("c.trees") - et.assert_equals(py_et) - - def test_runs(self): - ts = msprime.simulate(5, random_seed=126) - self.verify_extend_edges(ts) - - def test_max_iter(self): - ts = msprime.simulate(5, random_seed=126) - with pytest.raises(ValueError, match="max_iter"): - ets = ts.extend_edges(max_iter=0) - ets = ts.extend_edges(max_iter=1) - et = ets.extend_edges(max_iter=1).dump_tables() - eet = ets.extend_edges(max_iter=2).dump_tables() - eet.assert_equals(et) - - def test_simple_ex(self): - # An example where you need to go forwards *and* backwards: - # 7 and 8 should be extended to the whole sequence - # - # 6 6 6 6 - # +-+-+ +-+-+ +-+-+ +-+-+ - # | | 7 | | 8 | | - # | | ++-+ | | +-++ | | - # 4 5 4 | | 4 | 5 4 5 - # +++ +++ +++ | | | | +++ +++ +++ 
- # 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 - # - node_times = { - 0: 0, - 1: 0, - 2: 0, - 3: 0, - 4: 1.0, - 5: 1.0, - 6: 3.0, - 7: 2.0, - 8: 2.0, - } - # (p, c, l, r) - edge_stuff = [ - (4, 0, 0, 10), - (4, 1, 0, 5), - (4, 1, 7, 10), - (5, 2, 0, 2), - (5, 2, 5, 10), - (5, 3, 0, 2), - (5, 3, 5, 10), - (6, 4, 0, 2), - (6, 4, 5, 10), - (6, 5, 0, 2), - (6, 5, 7, 10), - (6, 3, 2, 5), - (6, 7, 2, 5), - (6, 8, 5, 7), - (7, 2, 2, 5), - (7, 4, 2, 5), - (8, 1, 5, 7), - (8, 5, 5, 7), - ] - tables = tskit.TableCollection(sequence_length=10) - nodes = tables.nodes - for n, t in node_times.items(): - flags = tskit.NODE_IS_SAMPLE if n < 4 else 0 - nodes.add_row(time=t, flags=flags) - edges = tables.edges - for p, c, l, r in edge_stuff: - edges.add_row(parent=p, child=c, left=l, right=r) - tables.sort() - ts = tables.tree_sequence() - ets = ts.extend_edges() - assert ts.num_edges == 18 - assert ets.num_edges == 13 - for t in ets.trees(): - assert 7 in t.nodes() - assert 8 in t.nodes() - assert t.parent(4) == 7 - assert t.parent(7) == 6 - assert t.parent(5) == 8 - assert t.parent(8) == 6 - self.verify_extend_edges(ts) - - def test_wright_fisher_trees(self): - tables = wf.wf_sim(N=5, ngens=20, deep_history=False, seed=3) - tables.sort() - tables.simplify() - ts = tables.tree_sequence() - self.verify_extend_edges(ts, max_iter=1) - self.verify_extend_edges(ts) - - def test_wright_fisher_trees_unsimplified(self): - tables = wf.wf_sim(N=6, ngens=22, deep_history=False, seed=4) - tables.sort() - ts = tables.tree_sequence() - self.verify_extend_edges(ts, max_iter=1) - self.verify_extend_edges(ts) - - def test_wright_fisher_trees_with_history(self): - tables = wf.wf_sim(N=8, ngens=15, deep_history=True, seed=5) - tables.sort() - tables.simplify() - ts = tables.tree_sequence() - self.verify_extend_edges(ts, max_iter=1) - self.verify_extend_edges(ts) - - # def test_bigger_wright_fisher(self): - # tables = wf.wf_sim(N=50, ngens=15, deep_history=True, seed=6) - # tables.sort() - # tables.simplify() - 
# ts = tables.tree_sequence() - # self.verify_extend_edges(ts, max_iter=1) - # self.verify_extend_edges(ts, max_iter=200) From d506073df4260557e19b2a76b7185b1b62889a96 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 18 Aug 2023 13:56:47 +0100 Subject: [PATCH 83/84] Fixup headers --- c/tskit/tables.h | 8 -------- c/tskit/trees.h | 4 ++-- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/c/tskit/tables.h b/c/tskit/tables.h index 048613de0f..38f3096c9d 100644 --- a/c/tskit/tables.h +++ b/c/tskit/tables.h @@ -42,14 +42,6 @@ extern "C" { #include -/****************************************************************************/ -/* Generic definitions */ -/****************************************************************************/ - -// These are also used in trees.h -#define TSK_DIR_FORWARD 1 -#define TSK_DIR_REVERSE -1 - /****************************************************************************/ /* Definitions for the basic objects */ /****************************************************************************/ diff --git a/c/tskit/trees.h b/c/tskit/trees.h index 738d399699..2305fb5ae3 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -56,8 +56,8 @@ extern "C" { /* Options for map_mutations */ #define TSK_MM_FIXED_ANCESTRAL_STATE (1 << 0) -/* For the edge diff iterator */ -#define TSK_INCLUDE_TERMINAL (1 << 0) +#define TSK_DIR_FORWARD 1 +#define TSK_DIR_REVERSE -1 /** @defgroup API_FLAGS_TS_INIT_GROUP :c:func:`tsk_treeseq_init` specific flags. 
From ebb79ac8c594e02d8ff1239e18e2d43c41057f74 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 18 Aug 2023 14:24:43 +0100 Subject: [PATCH 84/84] Add low-level py test --- python/tests/test_lowlevel.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index c33f159deb..5e5c5b2272 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -1496,6 +1496,13 @@ def test_time_units(self): ts.load_tables(tables) assert ts.get_time_units() == value + def test_extend_edges_bad_args(self): + ts1 = self.get_example_tree_sequence(10) + with pytest.raises(TypeError): + ts1.extend_edges() + with pytest.raises(TypeError, match="as an int"): + ts1.extend_edges("sdf") + def test_kc_distance_errors(self): ts1 = self.get_example_tree_sequence(10) with pytest.raises(TypeError):