Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ACRF-example-fixes #24

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions src/cc/mallet/grmm/examples/CrossTemplate1.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,26 @@ public class CrossTemplate1 extends ACRF.SequenceTemplate {
private int lvl1 = 0;
private int lvl2 = 1;

public CrossTemplate1 (int lvl1, int lvl2)
{
public CrossTemplate1(int lvl1, int lvl2) {
this.lvl1 = lvl1;
this.lvl2 = lvl2;
}

protected void addInstantiatedCliques (ACRF.UnrolledGraph graph, FeatureVectorSequence fvs, LabelsAssignment lblseq)
{
for (int t = 0; t < lblseq.size() - 1; t++) {
Variable var1 = lblseq.varOfIndex (t, lvl1);
Variable var2 = lblseq.varOfIndex (t + 1, lvl2);
assert var1 != null : "Couldn't get label factor "+lvl1+" time "+t;
assert var2 != null : "Couldn't get label factor "+lvl2+" time "+(t+1);
protected void addInstantiatedCliques(ACRF.UnrolledGraph graph, FeatureVectorSequence fvs, LabelsAssignment lblseq) {
for (int t = 0; t < lblseq.maxTime() - 1; t++) {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be looping over maxTime -- not over size()

try {
Variable var1 = lblseq.varOfIndex(t, lvl1);
Variable var2 = lblseq.varOfIndex(t + 1, lvl2);
assert var2 != null : "Couldn't get label factor " + lvl2 + " time " + (t + 1);
assert var1 != null : "Couldn't get label factor " + lvl1 + " time " + t;

Variable[] vars = new Variable[] { var1, var2 };
FeatureVector fv = fvs.getFeatureVector (t);
ACRF.UnrolledVarSet vs = new ACRF.UnrolledVarSet (graph, this, vars, fv);
graph.addClique (vs);
Variable[] vars = new Variable[]{var1, var2};
FeatureVector fv = fvs.getFeatureVector(t);
ACRF.UnrolledVarSet vs = new ACRF.UnrolledVarSet(graph, this, vars, fv);
graph.addClique(vs);
} catch (ArrayIndexOutOfBoundsException e) {
throw e;
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/cc/mallet/grmm/examples/SimpleCrfExample.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public static void main (String[] args) throws FileNotFoundException
true));

InstanceList testing = new InstanceList (pipe);
training.addThruPipe (new LineGroupIterator (new FileReader (testFile),
testing.addThruPipe (new LineGroupIterator (new FileReader (testFile),
Pattern.compile ("\\s*"),
true));

Expand Down
2 changes: 1 addition & 1 deletion src/cc/mallet/grmm/inference/TRP.java
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ public void computeMarginals (FactorGraph m)
dumpForIter (iter, tree);
}
iterUsed = iter;
logger.info ("TRP used " + iter + " iterations.");
logger.fine ("TRP used " + iter + " iterations.");

doneWithGraph (m);
}
Expand Down
30 changes: 19 additions & 11 deletions src/cc/mallet/grmm/learning/ACRF.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
* Class for Arbitrary CRFs. These are CRFs with completely
* arbitrary graphical structure. The user passes in a list
* of instances of ACRF.CliqueFactory, which get to look at
* the sequence and decide what
* the sequence and decide what
*
* @author <a href="mailto:[email protected]">Charles Sutton</a>
* @version $Id: ACRF.java,v 1.1 2007/10/22 21:37:43 mccallum Exp $
Expand Down Expand Up @@ -157,7 +157,7 @@ public void setGaussianPriorVariance (double gaussianPriorVariance)
{
this.gaussianPriorVariance = gaussianPriorVariance;
}


public void setGraphProcessor (GraphPostProcessor graphProcessor)
{
Expand Down Expand Up @@ -791,7 +791,7 @@ private void computeCPFs ()
addFactorInternal (clique, ptl);
clique.tmpl.modifyPotential (this, clique, ptl);
uvsMap.put (ptl, clique);

// sigh
LogTableFactor unif = new LogTableFactor (clique);
residTmp.add (Factors.distLinf (unif, ptl));
Expand Down Expand Up @@ -1823,17 +1823,25 @@ public void addInstantiatedCliques (ACRF.UnrolledGraph graph,
FeatureVectorSequence fvs,
LabelsAssignment lblseq)
{
if (lblseq.maxTime() == 1) {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if only one time slice you still need to add a clique

Variable v1 = lblseq.varOfIndex(0, factor);
FeatureVector fv = fvs.getFeatureVector(0);
ACRF.UnrolledVarSet clique = new ACRF.UnrolledVarSet(graph, this, new Variable[]{v1}, fv);
graph.addClique(clique);
return;
}

for (int i = 0; i < lblseq.maxTime() - 1; i++) {
Variable v1 = lblseq.varOfIndex (i, factor);
Variable v2 = lblseq.varOfIndex (i + 1, factor);
FeatureVector fv = fvs.getFeatureVector (i);
Variable v1 = lblseq.varOfIndex(i, factor);
Variable v2 = lblseq.varOfIndex(i + 1, factor);
FeatureVector fv = fvs.getFeatureVector(i);

Variable[] vars = new Variable[] { v1, v2 };
assert v1 != null : "Couldn't get label factor "+factor+" time "+i;
assert v2 != null : "Couldn't get label factor "+factor+" time "+(i+1);
Variable[] vars = new Variable[]{v1, v2};
assert v1 != null : "Couldn't get label factor " + factor + " time " + i;
assert v2 != null : "Couldn't get label factor " + factor + " time " + (i + 1);

ACRF.UnrolledVarSet clique = new ACRF.UnrolledVarSet (graph, this, vars, fv);
graph.addClique (clique);
ACRF.UnrolledVarSet clique = new ACRF.UnrolledVarSet(graph, this, vars, fv);
graph.addClique(clique);
}
}

Expand Down