-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
54e72ca
commit 3efe9ae
Showing
4 changed files
with
199 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
95 changes: 95 additions & 0 deletions
95
lphy-base/src/main/java/lphy/base/distribution/Multinomial.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
package lphy.base.distribution; | ||
|
||
import lphy.core.model.GenerativeDistribution; | ||
import lphy.core.model.RandomVariable; | ||
import lphy.core.model.Value; | ||
import lphy.core.model.annotation.ParameterInfo; | ||
|
||
import java.util.Arrays; | ||
import java.util.Map; | ||
import java.util.TreeMap; | ||
|
||
public class Multinomial implements GenerativeDistribution<Integer[]> { | ||
|
||
private Value<Integer> n; | ||
private Value<Double[]> p; | ||
private Value<Double[]> q; | ||
|
||
public Multinomial( | ||
@ParameterInfo(name = DistributionConstants.nParamName, description = "number of trials.") Value<Integer> n, | ||
@ParameterInfo(name = DistributionConstants.pParamName, description = "event probabilities.") Value<Double[]> p) { | ||
super(); | ||
this.n = n; | ||
this.p = p; | ||
|
||
|
||
} | ||
|
||
public Multinomial(){} | ||
|
||
|
||
@Override | ||
public RandomVariable<Integer[]> sample() { | ||
//org.apache.mahout.math.random.Multinomial multinomial = new org.apache.mahout.math.random.Multinomial(); | ||
Double[] q1 = new Double[this.p.value().length]; | ||
double cum_prob = 1.0; | ||
for (int i = 0; i < this.p.value().length; i++) { | ||
if (p.value()[i] == 0.0) | ||
q1[i] = 0.0; | ||
else { | ||
q1[i] = this.p.value()[i] / cum_prob; | ||
cum_prob = cum_prob - this.p.value()[i]; | ||
} | ||
} | ||
q = new Value<>("q", q1); | ||
int sampleSize = n.value(); | ||
Integer[] result = new Integer[p.value().length]; | ||
for (int i = 0; i < p.value().length-1; i++) { | ||
Value<Double> pro = new Value<Double>("pro", this.q.value()[i]); | ||
Value<Integer> sampleS = new Value<>("sample", sampleSize); | ||
Binomial binomial = new Binomial(pro, sampleS); | ||
int rand = binomial.sample().value(); | ||
result[i] = rand; | ||
sampleSize -= rand; | ||
|
||
if (sampleSize == 0) { | ||
Arrays.fill(result, i + 1, result.length, 0); | ||
break; | ||
} | ||
} | ||
|
||
result[p.value().length-1] = sampleSize; | ||
return new RandomVariable<>(null, result, this); | ||
} | ||
|
||
|
||
@Override | ||
public Map<String, Value> getParams() { | ||
return new TreeMap<>() {{ | ||
put(DistributionConstants.nParamName, n); | ||
put(DistributionConstants.pParamName, p); | ||
}}; | ||
} | ||
|
||
@Override | ||
public void setParam(String paramName, Value value) { | ||
switch (paramName) { | ||
case DistributionConstants.nParamName: | ||
n = value; | ||
break; | ||
case DistributionConstants.pParamName: | ||
p = value; | ||
break; | ||
default: | ||
throw new RuntimeException("Unrecognised parameter name: " + paramName); | ||
} | ||
} | ||
|
||
|
||
|
||
|
||
|
||
public String toString() { | ||
return getName(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
102 changes: 102 additions & 0 deletions
102
lphy-base/src/test/java/lphy/base/distribution/MultinormalTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
package lphy.base.distribution; | ||
|
||
import lphy.core.model.Value; | ||
import org.apache.commons.math3.stat.descriptive.moment.Mean; | ||
import org.apache.commons.math3.stat.descriptive.moment.Variance; | ||
import org.junit.jupiter.api.Test; | ||
|
||
import static org.junit.jupiter.api.Assertions.assertEquals; | ||
|
||
public class MultinormalTest { | ||
|
||
public double DELTA_MEAN = 1.0; | ||
public double DELTA_VARIANCE = 1000.0; | ||
|
||
/** | ||
* testing Multinomial moments | ||
* E(Xi) = n * pi | ||
* Var(Xi) = n * pi * (1 - pi) | ||
*/ | ||
@Test | ||
public void testMultinomial() { | ||
int nReplicates = 100000; | ||
Value<Integer> n = new Value<>("n", 100000); | ||
Double[] prob = {0.3, 0.2, 0.4, 0.1}; | ||
Value<Double[]> p = new Value<>("p", prob); | ||
Multinomial multinomial = new Multinomial(); | ||
multinomial.setParam("n", n); | ||
multinomial.setParam("p", p); | ||
|
||
int k = prob.length; | ||
double[][] results = new double[k][nReplicates]; | ||
double[] expectedMean = new double[k]; | ||
double[] expectedVariance = new double[k]; | ||
|
||
Value<Integer[]> result; | ||
for (int j = 0; j < nReplicates; j++) { | ||
result = multinomial.sample(); | ||
for (int i = 0; i < k; i++) { | ||
// expected mean | ||
expectedMean[i] = n.value() * prob[i]; | ||
expectedVariance[i] = expectedMean[i] * (1 - prob[i]); | ||
results[i][j] = (double) (result.value()[i]); | ||
} | ||
} | ||
for (int i = 0; i < k; i++) { | ||
Mean mean = new Mean(); | ||
double observedMean = mean.evaluate(results[i], 0, nReplicates); | ||
assertEquals(expectedMean[i], observedMean, DELTA_MEAN); | ||
// System.out.println("expectedMean: " + expectedMean[i] + ", expectedVariance: " + expectedVariance[i]); | ||
// System.out.println("mean = " + observedMean); | ||
Variance variance = new Variance(); | ||
double observedVariance = variance.evaluate(results[i], 0, nReplicates); | ||
// System.out.println("var = " + observedVariance); | ||
assertEquals(expectedVariance[i], observedVariance, DELTA_VARIANCE); | ||
} | ||
|
||
|
||
|
||
|
||
|
||
Double[] prob2 = {0.3, 0.2, 0.2, 0.3}; | ||
Value<Integer> n2 = new Value<>("n", 1000); | ||
Value<Double[]> p2 = new Value<>("p", prob2); | ||
multinomial.setParam("n", new Value<>("n", n2.value())); | ||
multinomial.setParam("p", new Value<>("p", p2.value())); | ||
int k2 = prob2.length; | ||
double[][] results2 = new double[k2][nReplicates]; | ||
double[] expectedMean2 = new double[k2]; | ||
double[] expectedVariance2 = new double[k2]; | ||
Value<Integer[]> result2; | ||
for (int j = 0; j < nReplicates; j++) { | ||
result2 = multinomial.sample(); | ||
for (int i = 0; i < k2; i++) { | ||
// expected mean | ||
expectedMean2[i] = n2.value() * prob2[i]; | ||
expectedVariance2[i] = expectedMean2[i] * (1 - prob2[i]); | ||
results2[i][j] = (double) (result2.value()[i]); | ||
} | ||
} | ||
for (int i = 0; i < k2; i++) { | ||
Mean mean2 = new Mean(); | ||
double observedMean2 = mean2.evaluate(results2[i], 0, nReplicates); | ||
assertEquals(expectedMean2[i], observedMean2, DELTA_MEAN); | ||
//System.out.println("expectedMean: " + expectedMean2[i] + ", expectedVariance: " + expectedVariance2[i]); | ||
//System.out.println("mean = " + observedMean2); | ||
Variance variance2 = new Variance(); | ||
double observedVariance2 = variance2.evaluate(results2[i], 0, nReplicates); | ||
//System.out.println("var = " + observedVariance2); | ||
assertEquals(expectedVariance2[i], observedVariance2, DELTA_VARIANCE); | ||
} | ||
|
||
|
||
|
||
result = multinomial.sample(); | ||
System.out.println(result); | ||
|
||
|
||
|
||
|
||
} | ||
|
||
} |