Added Multinomial #463
zjzxiaohei committed Apr 22, 2024
1 parent 54e72ca commit 3efe9ae
Showing 4 changed files with 199 additions and 0 deletions.
@@ -5,6 +5,7 @@ public class DistributionConstants {
public static final String pParamName = "p";
public static final String nParamName = "n";
public static final String rParamName = "r";
public static final String kParamName = "k";
public static final String shapeParamName = "shape";
public static final String alphaParamName = "alpha";
public static final String betaParamName = "beta";
95 changes: 95 additions & 0 deletions lphy-base/src/main/java/lphy/base/distribution/Multinomial.java
@@ -0,0 +1,95 @@
package lphy.base.distribution;

import lphy.core.model.GenerativeDistribution;
import lphy.core.model.RandomVariable;
import lphy.core.model.Value;
import lphy.core.model.annotation.ParameterInfo;

import java.util.Arrays;
import java.util.Map;
import java.util.TreeMap;

public class Multinomial implements GenerativeDistribution<Integer[]> {

    private Value<Integer> n;
    private Value<Double[]> p;
    // conditional success probabilities derived from p; rebuilt on every call to sample()
    private Value<Double[]> q;

    public Multinomial(
            @ParameterInfo(name = DistributionConstants.nParamName, description = "number of trials.") Value<Integer> n,
            @ParameterInfo(name = DistributionConstants.pParamName, description = "event probabilities.") Value<Double[]> p) {
        super();
        this.n = n;
        this.p = p;
    }

    public Multinomial() {
    }

    @Override
    public RandomVariable<Integer[]> sample() {
        // Sample by the conditional binomial method: the count for category i is drawn from
        // Binomial(trials remaining, p[i] / probability mass remaining), so q holds the
        // conditional probabilities and the last category receives whatever trials are left.
        Double[] q1 = new Double[this.p.value().length];
        double cumProb = 1.0;
        for (int i = 0; i < this.p.value().length; i++) {
            if (p.value()[i] == 0.0)
                q1[i] = 0.0;
            else {
                q1[i] = this.p.value()[i] / cumProb;
                cumProb = cumProb - this.p.value()[i];
            }
        }
        q = new Value<>("q", q1);

        int sampleSize = n.value();
        Integer[] result = new Integer[p.value().length];
        for (int i = 0; i < p.value().length - 1; i++) {
            Value<Double> pro = new Value<>("pro", this.q.value()[i]);
            Value<Integer> sampleS = new Value<>("sample", sampleSize);
            Binomial binomial = new Binomial(pro, sampleS);
            int rand = binomial.sample().value();
            result[i] = rand;
            sampleSize -= rand;

            // no trials left: the remaining categories all get zero counts
            if (sampleSize == 0) {
                Arrays.fill(result, i + 1, result.length, 0);
                break;
            }
        }
        // the last category takes the leftover trials
        result[p.value().length - 1] = sampleSize;
        return new RandomVariable<>(null, result, this);
    }


    @Override
    public Map<String, Value> getParams() {
        return new TreeMap<>() {{
            put(DistributionConstants.nParamName, n);
            put(DistributionConstants.pParamName, p);
        }};
    }

    @Override
    public void setParam(String paramName, Value value) {
        switch (paramName) {
            case DistributionConstants.nParamName:
                n = value;
                break;
            case DistributionConstants.pParamName:
                p = value;
                break;
            default:
                throw new RuntimeException("Unrecognised parameter name: " + paramName);
        }
    }

    public String toString() {
        return getName();
    }
}
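For reference, a minimal usage sketch of the new class, mirroring the calls exercised in the test below (the values and variable names here are illustrative only, not part of the commit):

    Value<Integer> n = new Value<>("n", 100);
    Value<Double[]> p = new Value<>("p", new Double[]{0.3, 0.2, 0.4, 0.1});
    Multinomial multinomial = new Multinomial(n, p);
    Integer[] counts = multinomial.sample().value();  // one draw; the counts sum to n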
1 change: 1 addition & 0 deletions lphy-base/src/main/java/lphy/base/spi/LPhyBaseImpl.java
@@ -53,6 +53,7 @@ public List<Class<? extends GenerativeDistribution>> declareDistributions() {
DiscretizedGamma.class, Exp.class, Gamma.class, Geometric.class, InverseGamma.class, LogNormal.class,
NegativeBinomial.class, Normal.class, NormalGamma.class, Poisson.class,
Uniform.class, UniformDiscrete.class, Weibull.class, WeightedDirichlet.class,
Multinomial.class,
// tree distribution
Yule.class, BirthDeathTree.class, FullBirthDeathTree.class, BirthDeathTreeDT.class,
BirthDeathSamplingTree.class, BirthDeathSamplingTreeDT.class, BirthDeathSerialSamplingTree.class,
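Registering Multinomial.class in declareDistributions() is what exposes the new distribution to the LPhy language. Assuming the parameter names declared above ("n" and "p"), a script usage would look roughly like:

    x ~ Multinomial(n=100, p=[0.3, 0.2, 0.4, 0.1]);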
102 changes: 102 additions & 0 deletions lphy-base/src/test/java/lphy/base/distribution/MultinomialTest.java
@@ -0,0 +1,102 @@
package lphy.base.distribution;

import lphy.core.model.Value;
import org.apache.commons.math3.stat.descriptive.moment.Mean;
import org.apache.commons.math3.stat.descriptive.moment.Variance;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;

public class MultinomialTest {

    private static final double DELTA_MEAN = 1.0;
    private static final double DELTA_VARIANCE = 1000.0;

    /**
     * Tests the Multinomial moments:
     * E(Xi) = n * pi
     * Var(Xi) = n * pi * (1 - pi)
     */
    @Test
    public void testMultinomial() {
        int nReplicates = 100000;
        Value<Integer> n = new Value<>("n", 100000);
        Double[] prob = {0.3, 0.2, 0.4, 0.1};
        Value<Double[]> p = new Value<>("p", prob);
        Multinomial multinomial = new Multinomial();
        multinomial.setParam("n", n);
        multinomial.setParam("p", p);

        int k = prob.length;
        double[][] results = new double[k][nReplicates];
        double[] expectedMean = new double[k];
        double[] expectedVariance = new double[k];

        Value<Integer[]> result;
        for (int j = 0; j < nReplicates; j++) {
            result = multinomial.sample();
            for (int i = 0; i < k; i++) {
                // expected moments: E(Xi) = n * pi, Var(Xi) = n * pi * (1 - pi)
                expectedMean[i] = n.value() * prob[i];
                expectedVariance[i] = expectedMean[i] * (1 - prob[i]);
                results[i][j] = (double) (result.value()[i]);
            }
        }
        for (int i = 0; i < k; i++) {
            Mean mean = new Mean();
            double observedMean = mean.evaluate(results[i], 0, nReplicates);
            assertEquals(expectedMean[i], observedMean, DELTA_MEAN);
            Variance variance = new Variance();
            double observedVariance = variance.evaluate(results[i], 0, nReplicates);
            assertEquals(expectedVariance[i], observedVariance, DELTA_VARIANCE);
        }

        // second configuration, reusing the same Multinomial instance
        Double[] prob2 = {0.3, 0.2, 0.2, 0.3};
        Value<Integer> n2 = new Value<>("n", 1000);
        Value<Double[]> p2 = new Value<>("p", prob2);
        multinomial.setParam("n", n2);
        multinomial.setParam("p", p2);
        int k2 = prob2.length;
        double[][] results2 = new double[k2][nReplicates];
        double[] expectedMean2 = new double[k2];
        double[] expectedVariance2 = new double[k2];
        Value<Integer[]> result2;
        for (int j = 0; j < nReplicates; j++) {
            result2 = multinomial.sample();
            for (int i = 0; i < k2; i++) {
                expectedMean2[i] = n2.value() * prob2[i];
                expectedVariance2[i] = expectedMean2[i] * (1 - prob2[i]);
                results2[i][j] = (double) (result2.value()[i]);
            }
        }
        for (int i = 0; i < k2; i++) {
            Mean mean2 = new Mean();
            double observedMean2 = mean2.evaluate(results2[i], 0, nReplicates);
            assertEquals(expectedMean2[i], observedMean2, DELTA_MEAN);
            Variance variance2 = new Variance();
            double observedVariance2 = variance2.evaluate(results2[i], 0, nReplicates);
            assertEquals(expectedVariance2[i], observedVariance2, DELTA_VARIANCE);
        }
    }

}
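As a quick worked check of the moment formulas the test asserts against, the first category of the first configuration (n = 100000, p1 = 0.3) gives

    E(X1)   = n * p1            = 100000 * 0.3 = 30000
    Var(X1) = n * p1 * (1 - p1) = 30000 * 0.7  = 21000

and these are the values the observed sample mean and variance are compared to, within the tolerances DELTA_MEAN = 1.0 and DELTA_VARIANCE = 1000.0.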
