Skip to content

Commit

Permalink
start BEAUti support for tipsampling #20
Browse files Browse the repository at this point in the history
  • Loading branch information
rbouckaert committed Mar 15, 2023
1 parent 88d6a97 commit a56a078
Show file tree
Hide file tree
Showing 4 changed files with 299 additions and 0 deletions.
123 changes: 123 additions & 0 deletions src/sa/beauti/SAMRCAPriorInputEditor.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package sa.beauti;

import java.lang.reflect.InvocationTargetException;
import java.util.List;

import beast.base.core.BEASTInterface;
import beast.base.core.Input;
import beast.base.evolution.alignment.TaxonSet;
import beast.base.inference.Operator;
import beastfx.app.inputeditor.BeautiDoc;
import beastfx.app.inputeditor.BooleanInputEditor;
import beastfx.app.inputeditor.InputEditor;
import beastfx.app.inputeditor.MRCAPriorInputEditor;
import javafx.scene.Node;
import javafx.scene.control.CheckBox;
import javafx.scene.layout.HBox;
import sa.evolution.operators.SampledNodeDateRandomWalker;
import sa.math.distributions.SAMRCAPrior;

public class SAMRCAPriorInputEditor extends MRCAPriorInputEditor {

public SAMRCAPriorInputEditor(BeautiDoc doc) {
super(doc);
}

public SAMRCAPriorInputEditor() {
super();
}

@Override
public Class<?> type() {
return SAMRCAPrior.class;
}

//InputEditor tipsonlyEditor;

public InputEditor createTipsonlyEditor() throws NoSuchMethodException, SecurityException, ClassNotFoundException, InstantiationException, IllegalAccessException, IllegalArgumentException, InvocationTargetException {
BooleanInputEditor e = new BooleanInputEditor (doc) {

@Override
public void init(Input<?> input, BEASTInterface beastObject, int itemNr, ExpandOption isExpandOption,
boolean addButtons) {
super.init(input, beastObject, itemNr, isExpandOption, addButtons);
for (Node o1 : getChildren()) {
if (o1 instanceof HBox) {
for (Node o : ((HBox)o1).getChildren()) {
if (o instanceof CheckBox) {
((CheckBox)o).setOnAction(e -> {
CheckBox src = (CheckBox) e.getSource();
if (src.isSelected()) {
enableTipSampling();
} else {
disableTipSampling(m_beastObject, doc);
}
});
}
}
}
}
}
};

SAMRCAPrior prior = (SAMRCAPrior) m_beastObject;
Input<?> input = prior.onlyUseTipsInput;
e.init(input, prior, -1, ExpandOption.FALSE, false);
return e;
}

// add TipDatesRandomWalker (if not present) and add to list of operators
private void enableTipSampling() {
// First, create/find the operator
SampledNodeDateRandomWalker operator = null;
SAMRCAPrior prior = (SAMRCAPrior) m_beastObject;
TaxonSet taxonset = prior.taxonsetInput.get();
taxonset.initAndValidate();

// see if an old operator still hangs around -- happens when toggling the TipsOnly checkbox a few times
for (BEASTInterface o : taxonset.getOutputs()) {
if (o instanceof SampledNodeDateRandomWalker) {
operator = (SampledNodeDateRandomWalker) o;
}
}

if (operator == null) {
operator = new SampledNodeDateRandomWalker();
operator.initByName("tree", prior.treeInput.get(), "taxonset", taxonset, "windowSize", 1.0, "weight", 1.0);
}
operator.setID("tipDatesSampler." + taxonset.getID());

doc.mcmc.get().setInputValue("operator", operator);
}

// remove TipDatesRandomWalker from list of operators
private static void disableTipSampling(BEASTInterface m_beastObject, BeautiDoc doc) {
// First, find the operator
SampledNodeDateRandomWalker operator = null;
SAMRCAPrior prior = (SAMRCAPrior) m_beastObject;
TaxonSet taxonset = prior.taxonsetInput.get();

// We cannot rely on the operator ID created in enableTipSampling()
// since the taxoneset name may have changed.
// However, if there is an TipDatesRandomWalker with taxonset as input, we want to remove it.
for (BEASTInterface o : taxonset.getOutputs()) {
if (o instanceof SampledNodeDateRandomWalker) {
operator = (SampledNodeDateRandomWalker) o;
}
}

if (operator == null) {
// should never happen
return;
}

// remove from list of operators
Object o = doc.mcmc.get().getInput("operator");
if (o instanceof Input<?>) {
Input<List<Operator>> operatorInput = (Input<List<Operator>>) o;
List<Operator> operators = operatorInput.get();
operators.remove(operator);
}
}

}
157 changes: 157 additions & 0 deletions src/sa/beauti/SAMRCAPriorProvider.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
package sa.beauti;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import beast.base.evolution.alignment.Taxon;
import beast.base.evolution.alignment.TaxonSet;
import beast.base.evolution.tree.Tree;
import beast.base.inference.Distribution;
import beast.base.inference.Logger;
import beast.base.inference.State;
import beast.base.inference.StateNode;
import beast.base.inference.distribution.OneOnX;
import beastfx.app.beauti.PriorListInputEditor;
import beastfx.app.beauti.PriorProvider;
import beastfx.app.inputeditor.BEASTObjectPanel;
import beastfx.app.inputeditor.BeautiDoc;
import beastfx.app.inputeditor.TaxonSetDialog;
import beastfx.app.util.Alert;
import sa.math.distributions.SAMRCAPrior;

public class SAMRCAPriorProvider implements PriorProvider {

@Override
public List<Distribution> createDistribution(BeautiDoc doc) {
SAMRCAPrior prior = new SAMRCAPrior();
try {

List<Tree> trees = new ArrayList<>();
doc.scrubAll(true, false);
State state = (State) doc.pluginmap.get("state");
for (StateNode node : state.stateNodeInput.get()) {
if (node instanceof Tree) { // && ((Tree) node).m_initial.get() != null) {
trees.add((Tree) node);
}
}
int treeIndex = 0;
if (trees.size() > 1) {
String[] treeIDs = new String[trees.size()];
for (int j = 0; j < treeIDs.length; j++) {
treeIDs[j] = trees.get(j).getID();
}
String treeID = (String) Alert.showInputDialog(null, "Select a tree", "MRCA selector", Alert.QUESTION_MESSAGE, null, treeIDs, trees.get(0));
treeIndex = 0;
while (treeIndex < treeIDs.length && !treeIDs[treeIndex].equals(treeID)) {
treeIndex++;
}
if (treeIndex == treeIDs.length) {
treeIndex = -1;
}
}
if (treeIndex < 0) {
return null;
}
prior.treeInput.setValue(trees.get(treeIndex), prior);
TaxonSet taxonSet = new TaxonSet();

TaxonSetDialog dlg = new TaxonSetDialog(taxonSet, getTaxonCandidates(prior, doc), doc);
if (!dlg.showDialog() || dlg.taxonSet.getID() == null || dlg.taxonSet.getID().trim().equals("")) {
return null;
}
taxonSet = dlg.taxonSet;
if (taxonSet.taxonsetInput.get().size() == 0) {
Alert.showMessageDialog(doc.beauti, "At least one taxon should be included in the taxon set",
"Error specifying taxon set", Alert.ERROR_MESSAGE);
return null;
}
int i = 1;
String id = taxonSet.getID();
while (doc.pluginmap.containsKey(taxonSet.getID()) && doc.pluginmap.get(taxonSet.getID()) != taxonSet) {
taxonSet.setID(id + i);
i++;
}
BEASTObjectPanel.addPluginToMap(taxonSet, doc);
prior.taxonsetInput.setValue(taxonSet, prior);
prior.setID(taxonSet.getID() + ".prior");
// this sets up the type
prior.distInput.setValue(new OneOnX(), prior);
// this removes the parametric distribution
prior.distInput.setValue(null, prior);

Logger logger = (Logger) doc.pluginmap.get("tracelog");
logger.loggersInput.setValue(prior, logger);
} catch (Exception e) {
// TODO: handle exception
}
List<Distribution> selectedPlugins = new ArrayList<>();
selectedPlugins.add(prior);
PriorListInputEditor.addCollapsedID(prior.getID());
return selectedPlugins;
}


/* expect args to be TaxonSet, Distribution, tree partition (if any) */
@Override
public List<Distribution> createDistribution(BeautiDoc doc, List<Object> args) {
SAMRCAPrior prior = new SAMRCAPrior();
TaxonSet taxonSet = (TaxonSet) args.get(0);
BEASTObjectPanel.addPluginToMap(taxonSet, doc);
prior.taxonsetInput.setValue(taxonSet, prior);
prior.setID(taxonSet.getID() + ".prior");
// this removes the parametric distribution
prior.distInput.setValue(args.get(1), prior);

Logger logger = (Logger) doc.pluginmap.get("tracelog");
logger.loggersInput.setValue(prior, logger);

if (args.size() <= 2) {
doc.scrubAll(true, false);
State state = (State) doc.pluginmap.get("state");
for (StateNode node : state.stateNodeInput.get()) {
if (node instanceof Tree) {
prior.treeInput.setValue(node, prior);
break;
}
}
} else {
Object tree = doc.pluginmap.get("Tree.t:" + args.get(2));
prior.treeInput.setValue(tree, prior);
}

List<Distribution> selectedPlugins = new ArrayList<>();
selectedPlugins.add(prior);
return selectedPlugins;
}

@Override
public String getDescription() {
return "Sampled Ancestors MRCA prior";
}


private Set<Taxon> getTaxonCandidates(SAMRCAPrior prior, BeautiDoc doc) {
Set<Taxon> candidates = new HashSet<>();
Tree tree = prior.treeInput.get();
String [] taxa = null;
if (tree.m_taxonset.get() != null) {
try {
TaxonSet set = tree.m_taxonset.get();
set.initAndValidate();
taxa = set.asStringList().toArray(new String[0]);
} catch (Exception e) {
taxa = prior.treeInput.get().getTaxaNames();
}
} else {
taxa = prior.treeInput.get().getTaxaNames();
}

for (String taxon : taxa) {
candidates.add(doc.getTaxon(taxon));
}
return candidates;
}

}
9 changes: 9 additions & 0 deletions src/sa/math/distributions/SAMRCAPrior.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package sa.math.distributions;

import beast.base.core.Description;
import beast.base.evolution.tree.MRCAPrior;

@Description("Behaves the same as a MRCAPrior, but allows BEAUti to know how to add the correct operators for tips sampling")
public class SAMRCAPrior extends MRCAPrior {

}
10 changes: 10 additions & 0 deletions version.xml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@
args=""
/>

<service type="beastfx.app.inputeditor.InputEditor">
<provider classname="sa.beauti.SAMRCAPriorInputEditor"/>
</service>

<service type="beastfx.app.beauti.PriorProvider">
<provider classname="sa.beauti.SAMRCAPriorProvider"/>
</service>


<service type="beast.base.core.BEASTInterface">
<provider classname="sa.app.tools.FullToExtantTreeConverter"/>
<provider classname="sa.app.tools.SampledAncestorTreeAnalyser"/>
Expand Down Expand Up @@ -44,6 +53,7 @@
<provider classname="sa.math.distributions.DegenerateBeta"/>
<provider classname="sa.math.distributions.DegenerateUniform"/>
<provider classname="sa.math.distributions.SpecialMRCAPrior"/>
<provider classname="sa.math.distributions.SAMRCAPrior"/>
<provider classname="sa.util.ClusterZBSATree"/>
<provider classname="sa.util.ZeroBranchSATreeParser"/>
</service>
Expand Down

0 comments on commit a56a078

Please sign in to comment.