diff --git a/tsconvert/newick.py b/tsconvert/newick.py index 9a933b9..1d54827 100644 --- a/tsconvert/newick.py +++ b/tsconvert/newick.py @@ -43,11 +43,15 @@ def from_ms(string): if tree_lines == "": raise ValueError( "Malformed input: no lines starting with [." - " Make sure you run ms with the -T and -r flags.") + " Make sure you run ms with the -T and -r flags." + ) trees = dendropy.TreeList.get( - data=tree_lines, schema="newick", extract_comment_metadata=True, - rooting="force-rooted") + data=tree_lines, + schema="newick", + extract_comment_metadata=True, + rooting="force-rooted", + ) if len(trees) == 0: raise ValueError("No valid trees in ms file") @@ -55,22 +59,28 @@ def from_ms(string): spans = [] for i, tree in enumerate(trees): try: - spans.append(float(tree.comments[0])) # the initial [X] is the first comment + spans.append( + float(tree.comments[0]) + ) # the initial [X] is the first comment except (ValueError, IndexError): raise ValueError( - "Problem reading integer # of positions spanned in tree {}".format(i) + - " (in ms format this preceeds the tree, in square braces)") + "Problem reading integer # of positions spanned in tree {}".format(i) + + " (in ms format this preceeds the tree, in square braces)" + ) if len(trees.taxon_namespace) != sum([1 for l in tree.leaf_node_iter()]): raise ValueError( "Tree {} does not have all {} expected tips".format( - i, len(trees.taxon_namespace))) + i, len(trees.taxon_namespace) + ) + ) # below we might want to set is_force_max_age=True to allow ancient tips # and work around branch length precision errors in ms tree.calc_node_ages() node_ages = [n.age for n in tree.ageorder_node_iter(include_leaves=False)] if len(set(node_ages)) != len(node_ages): raise ValueError( - "Tree {}: cannot have two internal nodes with the same time".format(i)) + "Tree {}: cannot have two internal nodes with the same time".format(i) + ) # NB: here we could check that the sequence_length == nsites, where nsites is given # in the ms_line, as the second number following the -r switch @@ -102,6 +112,73 @@ def from_ms(string): return tables.tree_sequence() +def from_beast(string, precision=6): + """ + Reads in a BEAST nexus description. + """ + trees = dendropy.TreeList.get(data=string, schema="nexus") + + print("Got", len(trees)) + positions = [] + for i, tree in enumerate(trees): + label_toks = tree.label.split() + assert label_toks[0] == "STATE" + positions.append(int(label_toks[1])) + # TODO likely some parameter tweaks would be needed to make this work reliably + tree.calc_node_ages() + # FIXME assuming for now what all samples are contemporary. + node_ages = [ + round(n.age, precision) + for n in tree.ageorder_node_iter(include_leaves=False) + ] + if len(set(node_ages)) != len(node_ages): + raise ValueError( + "Tree {}: cannot have two internal nodes with the same time".format(i) + ) + + assert len(positions) > 1 + assert positions[0] == 0 + diff = positions[-1] - positions[-2] + positions.append(positions[-1] + diff) + + tables = tskit.TableCollection(positions[-1]) + sample_id_map = {} + # Get the samples from the first tree. + for node in trees[0].leaf_node_iter(): + node_id = tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=node.age) + sample_id_map[node.taxon.label] = node_id + + age_id_map = {} + left = 0 + for right, tree in zip(positions[1:], trees): + for node in tree.ageorder_node_iter(include_leaves=False): + age = round(node.age, precision) + # print("node", age, node.annotations) + # print(node.description()) + children = list(node.child_nodes()) + if node.age not in age_id_map: + age_id_map[age] = tables.nodes.add_row(flags=0, time=age) + parent_id = age_id_map[age] + for child in children: + if child.is_leaf(): + child_id = sample_id_map[child.taxon.label] + else: + child_age = round(child.age, precision) + child_id = age_id_map[child_age] + tables.edges.add_row(left, right, parent_id, child_id) + left = right + + # import numpy as np + # print(tables.nodes) + # t = tables.nodes.time + # np.sort(t) + # print(t) + tables.sort() + # Simplify will squash together any edges, removing redundancy. + tables.simplify() + return tables.tree_sequence() + + def to_ms(ts): """ Returns an ms-formatted version of the specified tree sequence.