From 3efe9aeb3359f3c3247a9bb3041d40fe8508d7a7 Mon Sep 17 00:00:00 2001 From: zjzxiaohei <108013625+zjzxiaohei@users.noreply.github.com> Date: Mon, 22 Apr 2024 15:23:57 +1200 Subject: [PATCH] Added Multinomial #463 --- .../distribution/DistributionConstants.java | 1 + .../lphy/base/distribution/Multinomial.java | 95 ++++++++++++++++ .../main/java/lphy/base/spi/LPhyBaseImpl.java | 1 + .../base/distribution/MultinormalTest.java | 102 ++++++++++++++++++ 4 files changed, 199 insertions(+) create mode 100644 lphy-base/src/main/java/lphy/base/distribution/Multinomial.java create mode 100644 lphy-base/src/test/java/lphy/base/distribution/MultinormalTest.java diff --git a/lphy-base/src/main/java/lphy/base/distribution/DistributionConstants.java b/lphy-base/src/main/java/lphy/base/distribution/DistributionConstants.java index fa5529b45..e69294697 100644 --- a/lphy-base/src/main/java/lphy/base/distribution/DistributionConstants.java +++ b/lphy-base/src/main/java/lphy/base/distribution/DistributionConstants.java @@ -5,6 +5,7 @@ public class DistributionConstants { public static final String pParamName = "p"; public static final String nParamName = "n"; public static final String rParamName = "r"; + public static final String kParamName = "k"; public static final String shapeParamName = "shape"; public static final String alphaParamName = "alpha"; public static final String betaParamName = "beta"; diff --git a/lphy-base/src/main/java/lphy/base/distribution/Multinomial.java b/lphy-base/src/main/java/lphy/base/distribution/Multinomial.java new file mode 100644 index 000000000..c02961b81 --- /dev/null +++ b/lphy-base/src/main/java/lphy/base/distribution/Multinomial.java @@ -0,0 +1,95 @@ +package lphy.base.distribution; + +import lphy.core.model.GenerativeDistribution; +import lphy.core.model.RandomVariable; +import lphy.core.model.Value; +import lphy.core.model.annotation.ParameterInfo; + +import java.util.Arrays; +import java.util.Map; +import java.util.TreeMap; + +public class Multinomial implements GenerativeDistribution { + + private Value n; + private Value p; + private Value q; + + public Multinomial( + @ParameterInfo(name = DistributionConstants.nParamName, description = "number of trials.") Value n, + @ParameterInfo(name = DistributionConstants.pParamName, description = "event probabilities.") Value p) { + super(); + this.n = n; + this.p = p; + + + } + + public Multinomial(){} + + + @Override + public RandomVariable sample() { + //org.apache.mahout.math.random.Multinomial multinomial = new org.apache.mahout.math.random.Multinomial(); + Double[] q1 = new Double[this.p.value().length]; + double cum_prob = 1.0; + for (int i = 0; i < this.p.value().length; i++) { + if (p.value()[i] == 0.0) + q1[i] = 0.0; + else { + q1[i] = this.p.value()[i] / cum_prob; + cum_prob = cum_prob - this.p.value()[i]; + } + } + q = new Value<>("q", q1); + int sampleSize = n.value(); + Integer[] result = new Integer[p.value().length]; + for (int i = 0; i < p.value().length-1; i++) { + Value pro = new Value("pro", this.q.value()[i]); + Value sampleS = new Value<>("sample", sampleSize); + Binomial binomial = new Binomial(pro, sampleS); + int rand = binomial.sample().value(); + result[i] = rand; + sampleSize -= rand; + + if (sampleSize == 0) { + Arrays.fill(result, i + 1, result.length, 0); + break; + } + } + + result[p.value().length-1] = sampleSize; + return new RandomVariable<>(null, result, this); + } + + + @Override + public Map getParams() { + return new TreeMap<>() {{ + put(DistributionConstants.nParamName, n); + put(DistributionConstants.pParamName, p); + }}; + } + + @Override + public void setParam(String paramName, Value value) { + switch (paramName) { + case DistributionConstants.nParamName: + n = value; + break; + case DistributionConstants.pParamName: + p = value; + break; + default: + throw new RuntimeException("Unrecognised parameter name: " + paramName); + } + } + + + + + + public String toString() { + return getName(); + } + } diff --git a/lphy-base/src/main/java/lphy/base/spi/LPhyBaseImpl.java b/lphy-base/src/main/java/lphy/base/spi/LPhyBaseImpl.java index 98eee0a7a..4ddaafcc6 100644 --- a/lphy-base/src/main/java/lphy/base/spi/LPhyBaseImpl.java +++ b/lphy-base/src/main/java/lphy/base/spi/LPhyBaseImpl.java @@ -53,6 +53,7 @@ public List> declareDistributions() { DiscretizedGamma.class, Exp.class, Gamma.class, Geometric.class, InverseGamma.class, LogNormal.class, NegativeBinomial.class, Normal.class, NormalGamma.class, Poisson.class, Uniform.class, UniformDiscrete.class, Weibull.class, WeightedDirichlet.class, + Multinomial.class, // tree distribution Yule.class, BirthDeathTree.class, FullBirthDeathTree.class, BirthDeathTreeDT.class, BirthDeathSamplingTree.class, BirthDeathSamplingTreeDT.class, BirthDeathSerialSamplingTree.class, diff --git a/lphy-base/src/test/java/lphy/base/distribution/MultinormalTest.java b/lphy-base/src/test/java/lphy/base/distribution/MultinormalTest.java new file mode 100644 index 000000000..0995a91c8 --- /dev/null +++ b/lphy-base/src/test/java/lphy/base/distribution/MultinormalTest.java @@ -0,0 +1,102 @@ +package lphy.base.distribution; + +import lphy.core.model.Value; +import org.apache.commons.math3.stat.descriptive.moment.Mean; +import org.apache.commons.math3.stat.descriptive.moment.Variance; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class MultinormalTest { + + public double DELTA_MEAN = 1.0; + public double DELTA_VARIANCE = 1000.0; + + /** + * testing Multinomial moments + * E(Xi) = n * pi + * Var(Xi) = n * pi * (1 - pi) + */ + @Test + public void testMultinomial() { + int nReplicates = 100000; + Value n = new Value<>("n", 100000); + Double[] prob = {0.3, 0.2, 0.4, 0.1}; + Value p = new Value<>("p", prob); + Multinomial multinomial = new Multinomial(); + multinomial.setParam("n", n); + multinomial.setParam("p", p); + + int k = prob.length; + double[][] results = new double[k][nReplicates]; + double[] expectedMean = new double[k]; + double[] expectedVariance = new double[k]; + + Value result; + for (int j = 0; j < nReplicates; j++) { + result = multinomial.sample(); + for (int i = 0; i < k; i++) { + // expected mean + expectedMean[i] = n.value() * prob[i]; + expectedVariance[i] = expectedMean[i] * (1 - prob[i]); + results[i][j] = (double) (result.value()[i]); + } + } + for (int i = 0; i < k; i++) { + Mean mean = new Mean(); + double observedMean = mean.evaluate(results[i], 0, nReplicates); + assertEquals(expectedMean[i], observedMean, DELTA_MEAN); +// System.out.println("expectedMean: " + expectedMean[i] + ", expectedVariance: " + expectedVariance[i]); +// System.out.println("mean = " + observedMean); + Variance variance = new Variance(); + double observedVariance = variance.evaluate(results[i], 0, nReplicates); +// System.out.println("var = " + observedVariance); + assertEquals(expectedVariance[i], observedVariance, DELTA_VARIANCE); + } + + + + + + Double[] prob2 = {0.3, 0.2, 0.2, 0.3}; + Value n2 = new Value<>("n", 1000); + Value p2 = new Value<>("p", prob2); + multinomial.setParam("n", new Value<>("n", n2.value())); + multinomial.setParam("p", new Value<>("p", p2.value())); + int k2 = prob2.length; + double[][] results2 = new double[k2][nReplicates]; + double[] expectedMean2 = new double[k2]; + double[] expectedVariance2 = new double[k2]; + Value result2; + for (int j = 0; j < nReplicates; j++) { + result2 = multinomial.sample(); + for (int i = 0; i < k2; i++) { + // expected mean + expectedMean2[i] = n2.value() * prob2[i]; + expectedVariance2[i] = expectedMean2[i] * (1 - prob2[i]); + results2[i][j] = (double) (result2.value()[i]); + } + } + for (int i = 0; i < k2; i++) { + Mean mean2 = new Mean(); + double observedMean2 = mean2.evaluate(results2[i], 0, nReplicates); + assertEquals(expectedMean2[i], observedMean2, DELTA_MEAN); + //System.out.println("expectedMean: " + expectedMean2[i] + ", expectedVariance: " + expectedVariance2[i]); + //System.out.println("mean = " + observedMean2); + Variance variance2 = new Variance(); + double observedVariance2 = variance2.evaluate(results2[i], 0, nReplicates); + //System.out.println("var = " + observedVariance2); + assertEquals(expectedVariance2[i], observedVariance2, DELTA_VARIANCE); + } + + + + result = multinomial.sample(); + System.out.println(result); + + + + + } + +}