From a56a078cbe0de37df6cbe3efcfdde5f01adef1e9 Mon Sep 17 00:00:00 2001 From: rbouckaert Date: Thu, 16 Mar 2023 12:04:13 +1300 Subject: [PATCH] start BEAUti support for tipsampling #20 --- src/sa/beauti/SAMRCAPriorInputEditor.java | 123 ++++++++++++++++ src/sa/beauti/SAMRCAPriorProvider.java | 157 +++++++++++++++++++++ src/sa/math/distributions/SAMRCAPrior.java | 9 ++ version.xml | 10 ++ 4 files changed, 299 insertions(+) create mode 100644 src/sa/beauti/SAMRCAPriorInputEditor.java create mode 100644 src/sa/beauti/SAMRCAPriorProvider.java create mode 100644 src/sa/math/distributions/SAMRCAPrior.java diff --git a/src/sa/beauti/SAMRCAPriorInputEditor.java b/src/sa/beauti/SAMRCAPriorInputEditor.java new file mode 100644 index 0000000..d7165d9 --- /dev/null +++ b/src/sa/beauti/SAMRCAPriorInputEditor.java @@ -0,0 +1,123 @@ +package sa.beauti; + +import java.lang.reflect.InvocationTargetException; +import java.util.List; + +import beast.base.core.BEASTInterface; +import beast.base.core.Input; +import beast.base.evolution.alignment.TaxonSet; +import beast.base.inference.Operator; +import beastfx.app.inputeditor.BeautiDoc; +import beastfx.app.inputeditor.BooleanInputEditor; +import beastfx.app.inputeditor.InputEditor; +import beastfx.app.inputeditor.MRCAPriorInputEditor; +import javafx.scene.Node; +import javafx.scene.control.CheckBox; +import javafx.scene.layout.HBox; +import sa.evolution.operators.SampledNodeDateRandomWalker; +import sa.math.distributions.SAMRCAPrior; + +public class SAMRCAPriorInputEditor extends MRCAPriorInputEditor { + + public SAMRCAPriorInputEditor(BeautiDoc doc) { + super(doc); + } + + public SAMRCAPriorInputEditor() { + super(); + } + + @Override + public Class type() { + return SAMRCAPrior.class; + } + + //InputEditor tipsonlyEditor; + + public InputEditor createTipsonlyEditor() throws NoSuchMethodException, SecurityException, ClassNotFoundException, InstantiationException, IllegalAccessException, IllegalArgumentException, InvocationTargetException { + BooleanInputEditor e = new BooleanInputEditor (doc) { + + @Override + public void init(Input input, BEASTInterface beastObject, int itemNr, ExpandOption isExpandOption, + boolean addButtons) { + super.init(input, beastObject, itemNr, isExpandOption, addButtons); + for (Node o1 : getChildren()) { + if (o1 instanceof HBox) { + for (Node o : ((HBox)o1).getChildren()) { + if (o instanceof CheckBox) { + ((CheckBox)o).setOnAction(e -> { + CheckBox src = (CheckBox) e.getSource(); + if (src.isSelected()) { + enableTipSampling(); + } else { + disableTipSampling(m_beastObject, doc); + } + }); + } + } + } + } + } + }; + + SAMRCAPrior prior = (SAMRCAPrior) m_beastObject; + Input input = prior.onlyUseTipsInput; + e.init(input, prior, -1, ExpandOption.FALSE, false); + return e; + } + + // add TipDatesRandomWalker (if not present) and add to list of operators + private void enableTipSampling() { + // First, create/find the operator + SampledNodeDateRandomWalker operator = null; + SAMRCAPrior prior = (SAMRCAPrior) m_beastObject; + TaxonSet taxonset = prior.taxonsetInput.get(); + taxonset.initAndValidate(); + + // see if an old operator still hangs around -- happens when toggling the TipsOnly checkbox a few times + for (BEASTInterface o : taxonset.getOutputs()) { + if (o instanceof SampledNodeDateRandomWalker) { + operator = (SampledNodeDateRandomWalker) o; + } + } + + if (operator == null) { + operator = new SampledNodeDateRandomWalker(); + operator.initByName("tree", prior.treeInput.get(), "taxonset", taxonset, "windowSize", 1.0, "weight", 1.0); + } + operator.setID("tipDatesSampler." + taxonset.getID()); + + doc.mcmc.get().setInputValue("operator", operator); + } + + // remove TipDatesRandomWalker from list of operators + private static void disableTipSampling(BEASTInterface m_beastObject, BeautiDoc doc) { + // First, find the operator + SampledNodeDateRandomWalker operator = null; + SAMRCAPrior prior = (SAMRCAPrior) m_beastObject; + TaxonSet taxonset = prior.taxonsetInput.get(); + + // We cannot rely on the operator ID created in enableTipSampling() + // since the taxoneset name may have changed. + // However, if there is an TipDatesRandomWalker with taxonset as input, we want to remove it. + for (BEASTInterface o : taxonset.getOutputs()) { + if (o instanceof SampledNodeDateRandomWalker) { + operator = (SampledNodeDateRandomWalker) o; + } + } + + if (operator == null) { + // should never happen + return; + } + + // remove from list of operators + Object o = doc.mcmc.get().getInput("operator"); + if (o instanceof Input) { + Input> operatorInput = (Input>) o; + List operators = operatorInput.get(); + operators.remove(operator); + } + } + +} diff --git a/src/sa/beauti/SAMRCAPriorProvider.java b/src/sa/beauti/SAMRCAPriorProvider.java new file mode 100644 index 0000000..78a208d --- /dev/null +++ b/src/sa/beauti/SAMRCAPriorProvider.java @@ -0,0 +1,157 @@ +package sa.beauti; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import beast.base.evolution.alignment.Taxon; +import beast.base.evolution.alignment.TaxonSet; +import beast.base.evolution.tree.Tree; +import beast.base.inference.Distribution; +import beast.base.inference.Logger; +import beast.base.inference.State; +import beast.base.inference.StateNode; +import beast.base.inference.distribution.OneOnX; +import beastfx.app.beauti.PriorListInputEditor; +import beastfx.app.beauti.PriorProvider; +import beastfx.app.inputeditor.BEASTObjectPanel; +import beastfx.app.inputeditor.BeautiDoc; +import beastfx.app.inputeditor.TaxonSetDialog; +import beastfx.app.util.Alert; +import sa.math.distributions.SAMRCAPrior; + +public class SAMRCAPriorProvider implements PriorProvider { + + @Override + public List createDistribution(BeautiDoc doc) { + SAMRCAPrior prior = new SAMRCAPrior(); + try { + + List trees = new ArrayList<>(); + doc.scrubAll(true, false); + State state = (State) doc.pluginmap.get("state"); + for (StateNode node : state.stateNodeInput.get()) { + if (node instanceof Tree) { // && ((Tree) node).m_initial.get() != null) { + trees.add((Tree) node); + } + } + int treeIndex = 0; + if (trees.size() > 1) { + String[] treeIDs = new String[trees.size()]; + for (int j = 0; j < treeIDs.length; j++) { + treeIDs[j] = trees.get(j).getID(); + } + String treeID = (String) Alert.showInputDialog(null, "Select a tree", "MRCA selector", Alert.QUESTION_MESSAGE, null, treeIDs, trees.get(0)); + treeIndex = 0; + while (treeIndex < treeIDs.length && !treeIDs[treeIndex].equals(treeID)) { + treeIndex++; + } + if (treeIndex == treeIDs.length) { + treeIndex = -1; + } + } + if (treeIndex < 0) { + return null; + } + prior.treeInput.setValue(trees.get(treeIndex), prior); + TaxonSet taxonSet = new TaxonSet(); + + TaxonSetDialog dlg = new TaxonSetDialog(taxonSet, getTaxonCandidates(prior, doc), doc); + if (!dlg.showDialog() || dlg.taxonSet.getID() == null || dlg.taxonSet.getID().trim().equals("")) { + return null; + } + taxonSet = dlg.taxonSet; + if (taxonSet.taxonsetInput.get().size() == 0) { + Alert.showMessageDialog(doc.beauti, "At least one taxon should be included in the taxon set", + "Error specifying taxon set", Alert.ERROR_MESSAGE); + return null; + } + int i = 1; + String id = taxonSet.getID(); + while (doc.pluginmap.containsKey(taxonSet.getID()) && doc.pluginmap.get(taxonSet.getID()) != taxonSet) { + taxonSet.setID(id + i); + i++; + } + BEASTObjectPanel.addPluginToMap(taxonSet, doc); + prior.taxonsetInput.setValue(taxonSet, prior); + prior.setID(taxonSet.getID() + ".prior"); + // this sets up the type + prior.distInput.setValue(new OneOnX(), prior); + // this removes the parametric distribution + prior.distInput.setValue(null, prior); + + Logger logger = (Logger) doc.pluginmap.get("tracelog"); + logger.loggersInput.setValue(prior, logger); + } catch (Exception e) { + // TODO: handle exception + } + List selectedPlugins = new ArrayList<>(); + selectedPlugins.add(prior); + PriorListInputEditor.addCollapsedID(prior.getID()); + return selectedPlugins; + } + + + /* expect args to be TaxonSet, Distribution, tree partition (if any) */ + @Override + public List createDistribution(BeautiDoc doc, List args) { + SAMRCAPrior prior = new SAMRCAPrior(); + TaxonSet taxonSet = (TaxonSet) args.get(0); + BEASTObjectPanel.addPluginToMap(taxonSet, doc); + prior.taxonsetInput.setValue(taxonSet, prior); + prior.setID(taxonSet.getID() + ".prior"); + // this removes the parametric distribution + prior.distInput.setValue(args.get(1), prior); + + Logger logger = (Logger) doc.pluginmap.get("tracelog"); + logger.loggersInput.setValue(prior, logger); + + if (args.size() <= 2) { + doc.scrubAll(true, false); + State state = (State) doc.pluginmap.get("state"); + for (StateNode node : state.stateNodeInput.get()) { + if (node instanceof Tree) { + prior.treeInput.setValue(node, prior); + break; + } + } + } else { + Object tree = doc.pluginmap.get("Tree.t:" + args.get(2)); + prior.treeInput.setValue(tree, prior); + } + + List selectedPlugins = new ArrayList<>(); + selectedPlugins.add(prior); + return selectedPlugins; + } + + @Override + public String getDescription() { + return "Sampled Ancestors MRCA prior"; + } + + + private Set getTaxonCandidates(SAMRCAPrior prior, BeautiDoc doc) { + Set candidates = new HashSet<>(); + Tree tree = prior.treeInput.get(); + String [] taxa = null; + if (tree.m_taxonset.get() != null) { + try { + TaxonSet set = tree.m_taxonset.get(); + set.initAndValidate(); + taxa = set.asStringList().toArray(new String[0]); + } catch (Exception e) { + taxa = prior.treeInput.get().getTaxaNames(); + } + } else { + taxa = prior.treeInput.get().getTaxaNames(); + } + + for (String taxon : taxa) { + candidates.add(doc.getTaxon(taxon)); + } + return candidates; + } + +} diff --git a/src/sa/math/distributions/SAMRCAPrior.java b/src/sa/math/distributions/SAMRCAPrior.java new file mode 100644 index 0000000..a719d22 --- /dev/null +++ b/src/sa/math/distributions/SAMRCAPrior.java @@ -0,0 +1,9 @@ +package sa.math.distributions; + +import beast.base.core.Description; +import beast.base.evolution.tree.MRCAPrior; + +@Description("Behaves the same as a MRCAPrior, but allows BEAUti to know how to add the correct operators for tips sampling") +public class SAMRCAPrior extends MRCAPrior { + +} diff --git a/version.xml b/version.xml index 2f3bed5..70ff0ba 100644 --- a/version.xml +++ b/version.xml @@ -13,6 +13,15 @@ args="" /> + + + + + + + + + @@ -44,6 +53,7 @@ +