From 70cf55cc7554b668e421a564af6674161bda692d Mon Sep 17 00:00:00 2001 From: CT Date: Fri, 17 Jan 2020 00:29:06 +0800 Subject: [PATCH 1/4] Implement the MiniBatchKMeansClusterer and compare to KMeansPlusPlusClusterer --- .../math4/ml/clustering/ClusterUtils.java | 96 ++++++ .../clustering/MiniBatchKMeansClusterer.java | 281 ++++++++++++++++++ .../initialization/CentroidInitializer.java | 22 ++ .../KMeansPlusPlusCentroidInitializer.java | 169 +++++++++++ .../RandomCentroidInitializer.java | 41 +++ .../MiniBatchKMeansClustererTest.java | 111 +++++++ 6 files changed, 720 insertions(+) create mode 100644 src/main/java/org/apache/commons/math4/ml/clustering/ClusterUtils.java create mode 100644 src/main/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClusterer.java create mode 100644 src/main/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializer.java create mode 100644 src/main/java/org/apache/commons/math4/ml/clustering/initialization/KMeansPlusPlusCentroidInitializer.java create mode 100644 src/main/java/org/apache/commons/math4/ml/clustering/initialization/RandomCentroidInitializer.java create mode 100644 src/test/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClustererTest.java diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/ClusterUtils.java b/src/main/java/org/apache/commons/math4/ml/clustering/ClusterUtils.java new file mode 100644 index 0000000000..9c62348e75 --- /dev/null +++ b/src/main/java/org/apache/commons/math4/ml/clustering/ClusterUtils.java @@ -0,0 +1,96 @@ +package org.apache.commons.math4.ml.clustering; + +import org.apache.commons.math4.exception.ConvergenceException; +import org.apache.commons.math4.exception.util.LocalizedFormats; +import org.apache.commons.math4.ml.distance.DistanceMeasure; +import org.apache.commons.math4.stat.descriptive.moment.Variance; +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.simple.RandomSource; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +public class ClusterUtils { + private ClusterUtils() { + } + + public static ArrayList shuffle(Collection c, UniformRandomProvider random) { + ArrayList list = new ArrayList(c); + int size = list.size(); + for (int i = size; i > 1; --i) { + list.set(i - 1, list.set(random.nextInt(i), list.get(i - 1))); + } + return list; + } + + public static ArrayList shuffle(Collection points) { + return shuffle(points, RandomSource.create(RandomSource.MT_64)); + } + + /** + * Computes the centroid for a set of points. + * + * @param points the set of points + * @param dimension the point dimension + * @return the computed centroid for the set of points + */ + public static Clusterable centroidOf(final Collection points, final int dimension) { + final double[] centroid = new double[dimension]; + for (final T p : points) { + final double[] point = p.getPoint(); + for (int i = 0; i < centroid.length; i++) { + centroid[i] += point[i]; + } + } + for (int i = 0; i < centroid.length; i++) { + centroid[i] /= points.size(); + } + return new DoublePoint(centroid); + } + + + /** + * Get a random point from the {@link Cluster} with the largest distance variance. + * + * @param clusters the {@link Cluster}s to search + * @param measure DistanceMeasure + * @param random Random generator + * @return a random point from the selected cluster + * @throws ConvergenceException if clusters are all empty + */ + public static T getPointFromLargestVarianceCluster(final Collection> clusters, + final DistanceMeasure measure, + final UniformRandomProvider random) + throws ConvergenceException { + double maxVariance = Double.NEGATIVE_INFINITY; + Cluster selected = null; + for (final CentroidCluster cluster : clusters) { + if (!cluster.getPoints().isEmpty()) { + // compute the distance variance of the current cluster + final Clusterable center = cluster.getCenter(); + final Variance stat = new Variance(); + for (final T point : cluster.getPoints()) { + stat.increment(measure.compute(point.getPoint(), center.getPoint())); + } + final double variance = stat.getResult(); + + // select the cluster with the largest variance + if (variance > maxVariance) { + maxVariance = variance; + selected = cluster; + } + + } + } + + // did we find at least one non-empty cluster ? + if (selected == null) { + throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); + } + + // extract a random point from the cluster + final List selectedPoints = selected.getPoints(); + return selectedPoints.remove(random.nextInt(selectedPoints.size())); + } +} diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClusterer.java b/src/main/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClusterer.java new file mode 100644 index 0000000000..e662fbe415 --- /dev/null +++ b/src/main/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClusterer.java @@ -0,0 +1,281 @@ +package org.apache.commons.math4.ml.clustering; + +import org.apache.commons.math4.exception.ConvergenceException; +import org.apache.commons.math4.exception.MathIllegalArgumentException; +import org.apache.commons.math4.exception.NumberIsTooSmallException; +import org.apache.commons.math4.ml.clustering.initialization.CentroidInitializer; +import org.apache.commons.math4.ml.clustering.initialization.KMeansPlusPlusCentroidInitializer; +import org.apache.commons.math4.ml.distance.DistanceMeasure; +import org.apache.commons.math4.ml.distance.EuclideanDistance; +import org.apache.commons.math4.util.MathUtils; +import org.apache.commons.math4.util.Pair; +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.simple.RandomSource; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +/** + * A very fast clustering algorithm base on KMeans(Refer to Python sklearn.cluster.MiniBatchKMeans) + * Use a partial points in initialize cluster centers, and mini batch in iterations. + * It finish in few seconds when clustering millions of data, and has few differences between KMeans. + * See https://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf + * + * @param Type of the points to cluster + */ +public class MiniBatchKMeansClusterer extends Clusterer { + /** + * The number of clusters. + */ + private final int k; + + /** + * The maximum number of iterations. + */ + private final int maxIterations; + + /** + * Batch data size in iteration. + */ + private final int batchSize; + /** + * Iteration count of initialize the centers. + */ + private final int initIterations; + /** + * Data size of batch to initialize the centers, default 3*k + */ + private final int initBatchSize; + /** + * Max iterate times when no improvement on step iterations. + */ + private final int maxNoImprovementTimes; + /** + * Random generator for choosing initial centers. + */ + private final UniformRandomProvider random; + + private final CentroidInitializer centroidInitializer; + + + /** + * Build a clusterer. + * + * @param k the number of clusters to split the data into + * @param maxIterations the maximum number of iterations to run the algorithm for. + * If negative, no maximum will be used. + * @param batchSize the mini batch size for training iterations. + * @param initIterations the iterations to find out the best clusters centers. + * @param initBatchSize the mini batch size to initial the clusters centers. + * @param maxNoImprovementTimes the max iterations times when the square distance has no improvement. + * @param measure the distance measure to use + * @param random random generator to use for choosing initial centers + * may appear during algorithm iterations + * @param centroidInitializer the centroid initializer algorithm + */ + public MiniBatchKMeansClusterer(final int k, int maxIterations, final int batchSize, final int initIterations, + final int initBatchSize, final int maxNoImprovementTimes, + final DistanceMeasure measure, final UniformRandomProvider random, + final CentroidInitializer centroidInitializer) { + super(measure); + this.k = k; + this.maxIterations = maxIterations; + this.batchSize = batchSize; + this.initIterations = initIterations; + this.initBatchSize = initBatchSize; + this.maxNoImprovementTimes = maxNoImprovementTimes; + this.random = random; + this.centroidInitializer = centroidInitializer; + } + + /** + * Build a clusterer. + * + * @param k the number of clusters to split the data into + * @param maxIterations the maximum number of iterations to run the algorithm for. + * If negative, no maximum will be used. + * @param measure the distance measure to use + * @param random random generator to use for choosing initial centers + * may appear during algorithm iterations + */ + public MiniBatchKMeansClusterer(int k, int maxIterations, DistanceMeasure measure, UniformRandomProvider random) { + this(k, maxIterations, 100, 3, 100, 10, + measure, random, new KMeansPlusPlusCentroidInitializer(measure, random)); + } + + + /** + * Build a clusterer. + * + * @param k the number of clusters to split the data into + */ + public MiniBatchKMeansClusterer(int k) { + this(k, 100, new EuclideanDistance(), RandomSource.create(RandomSource.MT_64)); + } + + /** + * Runs the MiniBatch K-means clustering algorithm. + * + * @param points the points to cluster + * @return a list of clusters containing the points + * @throws MathIllegalArgumentException if the data points are null or the number + * of clusters is larger than the number of data points + */ + @Override + public List> cluster(Collection points) throws MathIllegalArgumentException, ConvergenceException { + // sanity checks + MathUtils.checkNotNull(points); + + // number of clusters has to be smaller or equal the number of data points + if (points.size() < k) { + throw new NumberIsTooSmallException(points.size(), k, false); + } + + int pointSize = points.size(); + int batchCount = pointSize / batchSize + ((pointSize % batchSize > 0) ? 1 : 0); + int maxIterations = this.maxIterations * batchCount; + MiniBatchImprovementEvaluator evaluator = new MiniBatchImprovementEvaluator(); + List> clusters = initialCenters(points); + for (int i = 0; i < maxIterations; i++) { + //清空上次的分类结果 + clearClustersPoints(clusters); + //随机抽样一批节点 + List batchPoints = randomMiniBatch(points, batchSize); + Pair>> pair = step(batchPoints, clusters); + double squareDistance = pair.getFirst(); + clusters = pair.getSecond(); + //评估改进情况 + if (evaluator.convergence(squareDistance, pointSize)) break; + } + clearClustersPoints(clusters); + //所有结点按质心分类 + for (T point : points) { + addToNearestCentroidCluster(point, clusters); + } + return clusters; + } + + /** + * clear clustered points + * + * @param clusters The clusters to clear + */ + private void clearClustersPoints(List> clusters) { + for (CentroidCluster cluster : clusters) { + cluster.getPoints().clear(); + } + } + + /** + * Mini batch iteration step + * + * @param batchPoints The mini batch points. + * @param clusters The cluster centers. + * @return Square distance of all the batch points to the nearest center, and newly clusters. + */ + protected Pair>> step( + List batchPoints, + List> clusters) { + //抽样结点归类 + for (T point : batchPoints) { + addToNearestCentroidCluster(point, clusters); + } + List> newClusters = new ArrayList>(clusters.size()); + //重算质心 + for (CentroidCluster cluster : clusters) { + Clusterable newCenter; + if (cluster.getPoints().isEmpty()) { + newCenter = new DoublePoint(ClusterUtils.getPointFromLargestVarianceCluster(clusters, this.getDistanceMeasure(), random).getPoint()); + } else { + newCenter = ClusterUtils.centroidOf(cluster.getPoints(), cluster.getCenter().getPoint().length); + } + newClusters.add(new CentroidCluster(newCenter)); + } + //重新归类抽样结点 + double squareDistance = 0.0; + for (T point : batchPoints) { + double d = addToNearestCentroidCluster(point, newClusters); + squareDistance += d * d; + } + return new Pair>>(squareDistance, newClusters); + } + + protected List randomMiniBatch(Collection points, int batchSize) { + return ClusterUtils.shuffle(new ArrayList(points), random).subList(0, batchSize); + } + + /** + * Initial cluster centers with multiply iterations, find out the best. + * + * @param points Points use to initial the cluster centers. + * @return Clusters with center + */ + protected List> initialCenters(Collection points) { + List validPoints = initBatchSize < points.size() ? randomMiniBatch(points, initBatchSize) : new ArrayList(points); + double nearestSquareDistance = Double.POSITIVE_INFINITY; + List> bestCenters = null; + for (int i = 0; i < initIterations; i++) { + List initialPoints = (initBatchSize < points.size()) ? randomMiniBatch(points, initBatchSize) : new ArrayList(points); + List> clusters = centroidInitializer.chooseCentroids(initialPoints, k); + Pair>> pair = step(validPoints, clusters); + double squareDistance = pair.getFirst(); + List> newClusters = pair.getSecond(); + //Find out a best centers that has the nearest total square distance. + if (squareDistance < nearestSquareDistance) { + nearestSquareDistance = squareDistance; + bestCenters = newClusters; + } + } + return bestCenters; + } + + /** + * Add a point to the cluster which the nearest center belong to. + * + * @param point The point to add. + * @param clusters The clusters to add to. + * @return The distance to nearest center. + */ + public double addToNearestCentroidCluster(T point, List> clusters) { + double minDistance = Double.POSITIVE_INFINITY; + CentroidCluster nearestCentroidCluster = null; + for (CentroidCluster centroidCluster : clusters) { + double distance = distance(point, centroidCluster.getCenter()); + if (distance < minDistance) { + minDistance = distance; + nearestCentroidCluster = centroidCluster; + } + } + assert nearestCentroidCluster != null; + nearestCentroidCluster.addPoint(point); + return minDistance; + } + + /** + * The Evaluator to evaluate whether the iteration should finish where square has no improvement for appointed times. + */ + class MiniBatchImprovementEvaluator { + private Double ewaInertia = null; + private double ewaInertiaMin = Double.POSITIVE_INFINITY; + private int noImprovementTimes = 0; + + protected boolean convergence(double squareDistance, int pointSize) { + double batchInertia = squareDistance / batchSize; + if (ewaInertia == null) { + ewaInertia = batchInertia; + } else { + double alpha = batchSize * 2.0 / (pointSize + 1); + alpha = Math.min(alpha, 1.0); + ewaInertia = ewaInertia * (1 - alpha) + batchInertia * alpha; + } + if (ewaInertia < ewaInertiaMin) { + noImprovementTimes = 0; + ewaInertiaMin = ewaInertia; + } else { + noImprovementTimes++; + } + return noImprovementTimes >= maxNoImprovementTimes; + } + } +} diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializer.java b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializer.java new file mode 100644 index 0000000000..9a188edde2 --- /dev/null +++ b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializer.java @@ -0,0 +1,22 @@ +package org.apache.commons.math4.ml.clustering.initialization; + +import org.apache.commons.math4.ml.clustering.CentroidCluster; +import org.apache.commons.math4.ml.clustering.Clusterable; + +import java.util.Collection; +import java.util.List; + +/** + * Interface abstract the algorithm for clusterer to choose the initial centers. + */ +public interface CentroidInitializer { + + /** + * Choose the initial centers. + * + * @param points the points to choose the initial centers from + * @param k The number of clusters + * @return the initial centers + */ + List> chooseCentroids(final Collection points, final int k); +} diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/KMeansPlusPlusCentroidInitializer.java b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/KMeansPlusPlusCentroidInitializer.java new file mode 100644 index 0000000000..ef19179855 --- /dev/null +++ b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/KMeansPlusPlusCentroidInitializer.java @@ -0,0 +1,169 @@ +package org.apache.commons.math4.ml.clustering.initialization; + +import org.apache.commons.math4.ml.clustering.CentroidCluster; +import org.apache.commons.math4.ml.clustering.Clusterable; +import org.apache.commons.math4.ml.distance.DistanceMeasure; +import org.apache.commons.rng.UniformRandomProvider; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +/** + * Use K-means++ to choose the initial centers. + * + * @see K-means++ (wikipedia) + */ +public class KMeansPlusPlusCentroidInitializer implements CentroidInitializer { + private final DistanceMeasure measure; + private final UniformRandomProvider random; + + /** + * Build a K-means++ CentroidInitializer + * @param measure the distance measure to use + * @param random the random to use. + */ + public KMeansPlusPlusCentroidInitializer(final DistanceMeasure measure, final UniformRandomProvider random) { + this.measure = measure; + this.random = random; + } + + /** + * Use K-means++ to choose the initial centers. + * + * @param points the points to choose the initial centers from + * @param k The number of clusters + * @return the initial centers + */ + @Override + public List> chooseCentroids(final Collection points, final int k) { + // Convert to list for indexed access. Make it unmodifiable, since removal of items + // would screw up the logic of this method. + final List pointList = Collections.unmodifiableList(new ArrayList<>(points)); + + // The number of points in the list. + final int numPoints = pointList.size(); + + // Set the corresponding element in this array to indicate when + // elements of pointList are no longer available. + final boolean[] taken = new boolean[numPoints]; + + // The resulting list of initial centers. + final List> resultSet = new ArrayList<>(); + + // Choose one center uniformly at random from among the data points. + final int firstPointIndex = random.nextInt(numPoints); + + final T firstPoint = pointList.get(firstPointIndex); + + resultSet.add(new CentroidCluster(firstPoint)); + + // Must mark it as taken + taken[firstPointIndex] = true; + + // To keep track of the minimum distance squared of elements of + // pointList to elements of resultSet. + final double[] minDistSquared = new double[numPoints]; + + // Initialize the elements. Since the only point in resultSet is firstPoint, + // this is very easy. + for (int i = 0; i < numPoints; i++) { + if (i != firstPointIndex) { // That point isn't considered + double d = distance(firstPoint, pointList.get(i)); + minDistSquared[i] = d * d; + } + } + + while (resultSet.size() < k) { + + // Sum up the squared distances for the points in pointList not + // already taken. + double distSqSum = 0.0; + + for (int i = 0; i < numPoints; i++) { + if (!taken[i]) { + distSqSum += minDistSquared[i]; + } + } + + // Add one new data point as a center. Each point x is chosen with + // probability proportional to D(x)2 + final double r = random.nextDouble() * distSqSum; + + // The index of the next point to be added to the resultSet. + int nextPointIndex = -1; + + // Sum through the squared min distances again, stopping when + // sum >= r. + double sum = 0.0; + for (int i = 0; i < numPoints; i++) { + if (!taken[i]) { + sum += minDistSquared[i]; + if (sum >= r) { + nextPointIndex = i; + break; + } + } + } + + // If it's not set to >= 0, the point wasn't found in the previous + // for loop, probably because distances are extremely small. Just pick + // the last available point. + if (nextPointIndex == -1) { + for (int i = numPoints - 1; i >= 0; i--) { + if (!taken[i]) { + nextPointIndex = i; + break; + } + } + } + + // We found one. + if (nextPointIndex >= 0) { + + final T p = pointList.get(nextPointIndex); + + resultSet.add(new CentroidCluster(p)); + + // Mark it as taken. + taken[nextPointIndex] = true; + + if (resultSet.size() < k) { + // Now update elements of minDistSquared. We only have to compute + // the distance to the new center to do this. + for (int j = 0; j < numPoints; j++) { + // Only have to worry about the points still not taken. + if (!taken[j]) { + double d = distance(p, pointList.get(j)); + double d2 = d * d; + if (d2 < minDistSquared[j]) { + minDistSquared[j] = d2; + } + } + } + } + + } else { + // None found -- + // Break from the while loop to prevent + // an infinite loop. + break; + } + } + + return resultSet; + } + + /** + * Calculates the distance between two {@link Clusterable} instances + * with the configured {@link DistanceMeasure}. + * + * @param p1 the first clusterable + * @param p2 the second clusterable + * @return the distance between the two clusterables + */ + protected double distance(final Clusterable p1, final Clusterable p2) { + return measure.compute(p1.getPoint(), p2.getPoint()); + } +} diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/RandomCentroidInitializer.java b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/RandomCentroidInitializer.java new file mode 100644 index 0000000000..55b910d6c3 --- /dev/null +++ b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/RandomCentroidInitializer.java @@ -0,0 +1,41 @@ +package org.apache.commons.math4.ml.clustering.initialization; + +import org.apache.commons.math4.ml.clustering.CentroidCluster; +import org.apache.commons.math4.ml.clustering.ClusterUtils; +import org.apache.commons.math4.ml.clustering.Clusterable; +import org.apache.commons.rng.UniformRandomProvider; + +import java.util.*; + +/** + * Random choose the initial centers. + */ +public class RandomCentroidInitializer implements CentroidInitializer { + private final UniformRandomProvider random; + + /** + * Build a random RandomCentroidInitializer + * + * @param random the random to use. + */ + public RandomCentroidInitializer(final UniformRandomProvider random) { + this.random = random; + } + + /** + * Random choose the initial centers. + * + * @param points the points to choose the initial centers from + * @param k The number of clusters + * @return the initial centers + */ + @Override + public List> chooseCentroids(Collection points, int k) { + ArrayList list = ClusterUtils.shuffle(points, random); + List> result = new ArrayList<>(k); + for (int i = 0; i < k; i++) { + result.add(new CentroidCluster<>(list.get(i))); + } + return result; + } +} diff --git a/src/test/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClustererTest.java b/src/test/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClustererTest.java new file mode 100644 index 0000000000..2425ae86e5 --- /dev/null +++ b/src/test/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClustererTest.java @@ -0,0 +1,111 @@ +package org.apache.commons.math4.ml.clustering; + +import org.apache.commons.math4.linear.MatrixUtils; +import org.apache.commons.math4.linear.RealMatrix; +import org.apache.commons.math4.ml.distance.DistanceMeasure; +import org.apache.commons.math4.ml.distance.EuclideanDistance; +import org.apache.commons.math4.random.CorrelatedRandomVectorGenerator; +import org.apache.commons.math4.random.GaussianRandomGenerator; +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.simple.RandomSource; +import org.junit.Assert; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +public class MiniBatchKMeansClustererTest { + private DistanceMeasure measure = new EuclideanDistance(); + + @Test + public void compareToKMeans() { + //Generate 4 cluster + List data = generateCircles(); + KMeansPlusPlusClusterer kMeans = new KMeansPlusPlusClusterer<>(4); + MiniBatchKMeansClusterer miniBatchKMeans = new MiniBatchKMeansClusterer<>(4); + List> kMeansClusters = kMeans.cluster(data); + List> miniBatchKMeansClusters = miniBatchKMeans.cluster(data); + Assert.assertEquals(4, kMeansClusters.size()); + Assert.assertEquals(kMeansClusters.size(), miniBatchKMeansClusters.size()); + int totalDiffCount = 0; + double totalCenterDistance = 0.0; + for (CentroidCluster kMeanCluster : kMeansClusters) { + CentroidCluster miniBatchCluster = predict(kMeanCluster.getCenter(), miniBatchKMeansClusters); + totalDiffCount += Math.abs(kMeanCluster.getPoints().size() - miniBatchCluster.getPoints().size()); + totalCenterDistance += measure.compute(kMeanCluster.getCenter().getPoint(), miniBatchCluster.getCenter().getPoint()); + } + double diffRatio = totalDiffCount * 1.0 / data.size(); + System.out.println(String.format("Centers total distance: %f, clusters total diff points: %d, diff ratio: %f%%", + totalCenterDistance, totalDiffCount, diffRatio * 100)); + // Difference ratio less than 2% + Assert.assertTrue(String.format("Different points ratio %f%%!", diffRatio * 100), diffRatio < 0.02); + } + + private CentroidCluster predict(Clusterable point, List> clusters) { + double minDistance = Double.POSITIVE_INFINITY; + CentroidCluster nearestCluster = null; + for (CentroidCluster cluster : clusters) { + double distance = measure.compute(point.getPoint(), cluster.getCenter().getPoint()); + if (distance < minDistance) { + minDistance = distance; + nearestCluster = cluster; + } + } + return nearestCluster; + } + + private List generateClusters() { + List data = new ArrayList<>(); + data.addAll(generateCluster(250, new double[]{-1.0, -1.0}, 0.5)); + data.addAll(generateCluster(250, new double[]{0.0, 0.0}, 0.5)); + data.addAll(generateCluster(250, new double[]{1.0, 1.0}, 0.5)); + data.addAll(generateCluster(250, new double[]{2.0, 2.0}, 0.5)); + return data; + } + + private List generateCluster(int size, double[] center, double radius) { + UniformRandomProvider rg = RandomSource.create(RandomSource.MT_64, 0); + GaussianRandomGenerator rawGenerator = new GaussianRandomGenerator(rg); + double[] standardDeviation = {0.5, 0.5}; + double c = standardDeviation[0] * standardDeviation[1] * radius; + double[][] cov = {{standardDeviation[0] * standardDeviation[0], c}, {c, standardDeviation[1] * standardDeviation[1]}}; + RealMatrix covariance = MatrixUtils.createRealMatrix(cov); + // Create a CorrelatedRandomVectorGenerator using "rawGenerator" for the components. + CorrelatedRandomVectorGenerator generator = + new CorrelatedRandomVectorGenerator(center, covariance, 1.0e-12 * covariance.getNorm(), rawGenerator); + // Use the generator to generate correlated vectors. + List data = new ArrayList(size); + for (int i = 0; i < size; i++) { + // Use the generator to generate vectors + double[] randomVector = generator.nextVector(); + data.add(new DoublePoint(randomVector)); + } + return data; + } + + private List generateCircles() { + List data = new ArrayList<>(); + Random random = new Random(0); + data.addAll(generateCircle(250, new double[]{-1.0, -1.0}, 1.0, random)); + data.addAll(generateCircle(250, new double[]{0.0, 0.0}, 0.7, random)); + data.addAll(generateCircle(250, new double[]{1.0, 1.0}, 0.7, random)); + data.addAll(generateCircle(250, new double[]{2.0, 2.0}, 0.7, random)); + return data; + } + + List generateCircle(int count, double[] center, double radius, Random random) { + double x0 = center[0]; + double y0 = center[1]; + ArrayList list = new ArrayList(count); + for (int i = 0; i < count; i++) { + double ao = random.nextDouble() * 720 - 360; + double r = random.nextDouble() * radius * 2 - radius; + double x1 = x0 + r * Math.cos(ao * Math.PI / 180); + double y1 = y0 + r * Math.sin(ao * Math.PI / 180); + list.add(new DoublePoint(new double[]{x1, y1})); + } + return list; + } + +} From 5f0453c822a1962f2608d20f6cb3986bcbc463b7 Mon Sep 17 00:00:00 2001 From: CT Date: Sat, 18 Jan 2020 13:58:37 +0800 Subject: [PATCH 2/4] Fix some code style problem for MiniBatchKMeansClusterer --- .../math4/ml/clustering/ClusterUtils.java | 48 +++++++-- .../clustering/MiniBatchKMeansClusterer.java | 61 ++++++++--- .../RandomCentroidInitializer.java | 9 +- .../MiniBatchKMeansClustererTest.java | 101 ++++++------------ 4 files changed, 121 insertions(+), 98 deletions(-) diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/ClusterUtils.java b/src/main/java/org/apache/commons/math4/ml/clustering/ClusterUtils.java index 9c62348e75..e4ae9ce75f 100644 --- a/src/main/java/org/apache/commons/math4/ml/clustering/ClusterUtils.java +++ b/src/main/java/org/apache/commons/math4/ml/clustering/ClusterUtils.java @@ -3,29 +3,57 @@ import org.apache.commons.math4.exception.ConvergenceException; import org.apache.commons.math4.exception.util.LocalizedFormats; import org.apache.commons.math4.ml.distance.DistanceMeasure; +import org.apache.commons.math4.ml.distance.EuclideanDistance; import org.apache.commons.math4.stat.descriptive.moment.Variance; import org.apache.commons.rng.UniformRandomProvider; -import org.apache.commons.rng.simple.RandomSource; -import java.util.ArrayList; import java.util.Collection; import java.util.List; +/** + * Common functions used in clustering + */ public class ClusterUtils { + /** + * Use only for static + */ private ClusterUtils() { } - public static ArrayList shuffle(Collection c, UniformRandomProvider random) { - ArrayList list = new ArrayList(c); - int size = list.size(); - for (int i = size; i > 1; --i) { - list.set(i - 1, list.set(random.nextInt(i), list.get(i - 1))); + public static final DistanceMeasure DEFAULT_MEASURE = new EuclideanDistance(); + + /** + * Predict which cluster is best for the point + * + * @param clusters cluster to predict into + * @param point point to predict + * @param measure distance measurer + * @param type of cluster point + * @return the cluster which has nearest center to the point + */ + public static CentroidCluster predict(List> clusters, Clusterable point, DistanceMeasure measure) { + double minDistance = Double.POSITIVE_INFINITY; + CentroidCluster nearestCluster = null; + for (CentroidCluster cluster : clusters) { + double distance = measure.compute(point.getPoint(), cluster.getCenter().getPoint()); + if (distance < minDistance) { + minDistance = distance; + nearestCluster = cluster; + } } - return list; + return nearestCluster; } - public static ArrayList shuffle(Collection points) { - return shuffle(points, RandomSource.create(RandomSource.MT_64)); + /** + * Predict which cluster is best for the point + * + * @param clusters cluster to predict into + * @param point point to predict + * @param type of cluster point + * @return the cluster which has nearest center to the point + */ + public static CentroidCluster predict(List> clusters, Clusterable point) { + return predict(clusters, point, DEFAULT_MEASURE); } /** diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClusterer.java b/src/main/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClusterer.java index e662fbe415..6a9fe97bd0 100644 --- a/src/main/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClusterer.java +++ b/src/main/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClusterer.java @@ -10,6 +10,7 @@ import org.apache.commons.math4.util.MathUtils; import org.apache.commons.math4.util.Pair; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.ListSampler; import org.apache.commons.rng.simple.RandomSource; import java.util.ArrayList; @@ -56,6 +57,9 @@ public class MiniBatchKMeansClusterer extends Clusterer 0 ? maxIterations : 100; this.batchSize = batchSize; this.initIterations = initIterations; this.initBatchSize = initBatchSize; @@ -100,7 +104,7 @@ public MiniBatchKMeansClusterer(final int k, int maxIterations, final int batchS * may appear during algorithm iterations */ public MiniBatchKMeansClusterer(int k, int maxIterations, DistanceMeasure measure, UniformRandomProvider random) { - this(k, maxIterations, 100, 3, 100, 10, + this(k, maxIterations, 100, 3, 100 * 3, 10, measure, random, new KMeansPlusPlusCentroidInitializer(measure, random)); } @@ -138,18 +142,19 @@ public List> cluster(Collection points) throws MathIllegal MiniBatchImprovementEvaluator evaluator = new MiniBatchImprovementEvaluator(); List> clusters = initialCenters(points); for (int i = 0; i < maxIterations; i++) { - //清空上次的分类结果 + //Clear points in clusters clearClustersPoints(clusters); - //随机抽样一批节点 + //Random sampling a mini batch of points. List batchPoints = randomMiniBatch(points, batchSize); + // Processing the mini batch training step Pair>> pair = step(batchPoints, clusters); double squareDistance = pair.getFirst(); clusters = pair.getSecond(); - //评估改进情况 + // Evaluate the training can finished early. if (evaluator.convergence(squareDistance, pointSize)) break; } clearClustersPoints(clusters); - //所有结点按质心分类 + //Add every mini batch points to their nearest cluster. for (T point : points) { addToNearestCentroidCluster(point, clusters); } @@ -174,15 +179,15 @@ private void clearClustersPoints(List> clusters) { * @param clusters The cluster centers. * @return Square distance of all the batch points to the nearest center, and newly clusters. */ - protected Pair>> step( + private Pair>> step( List batchPoints, List> clusters) { - //抽样结点归类 + //Add every mini batch points to their nearest cluster. for (T point : batchPoints) { addToNearestCentroidCluster(point, clusters); } List> newClusters = new ArrayList>(clusters.size()); - //重算质心 + //Refresh then cluster centroid. for (CentroidCluster cluster : clusters) { Clusterable newCenter; if (cluster.getPoints().isEmpty()) { @@ -192,7 +197,7 @@ protected Pair>> step( } newClusters.add(new CentroidCluster(newCenter)); } - //重新归类抽样结点 + // Add every mini batch points to their nearest cluster again. double squareDistance = 0.0; for (T point : batchPoints) { double d = addToNearestCentroidCluster(point, newClusters); @@ -201,8 +206,21 @@ protected Pair>> step( return new Pair>>(squareDistance, newClusters); } - protected List randomMiniBatch(Collection points, int batchSize) { - return ClusterUtils.shuffle(new ArrayList(points), random).subList(0, batchSize); + /** + * Get a mini batch of points + * + * @param points all the points + * @param batchSize the mini batch size + * @return mini batch of all the points + */ + private List randomMiniBatch(Collection points, int batchSize) { + ArrayList list = new ArrayList(points); + ListSampler.shuffle(random, list); +// int size = list.size(); +// for (int i = size; i > 1; --i) { +// list.set(i - 1, list.set(random.nextInt(i), list.get(i - 1))); +// } + return list.subList(0, batchSize); } /** @@ -211,7 +229,7 @@ protected List randomMiniBatch(Collection points, int batchSize) { * @param points Points use to initial the cluster centers. * @return Clusters with center */ - protected List> initialCenters(Collection points) { + private List> initialCenters(Collection points) { List validPoints = initBatchSize < points.size() ? randomMiniBatch(points, initBatchSize) : new ArrayList(points); double nearestSquareDistance = Double.POSITIVE_INFINITY; List> bestCenters = null; @@ -237,7 +255,7 @@ protected List> initialCenters(Collection points) { * @param clusters The clusters to add to. * @return The distance to nearest center. */ - public double addToNearestCentroidCluster(T point, List> clusters) { + private double addToNearestCentroidCluster(T point, List> clusters) { double minDistance = Double.POSITIVE_INFINITY; CentroidCluster nearestCentroidCluster = null; for (CentroidCluster centroidCluster : clusters) { @@ -260,21 +278,34 @@ class MiniBatchImprovementEvaluator { private double ewaInertiaMin = Double.POSITIVE_INFINITY; private int noImprovementTimes = 0; - protected boolean convergence(double squareDistance, int pointSize) { + /** + * Evaluate whether the iteration should finish where square has no improvement for appointed times + * + * @param squareDistance the total square distance of the mini batch points to their nearest center. + * @param pointSize size of the the data points. + * @return true if no improvement for appointed times, otherwise false + */ + public boolean convergence(double squareDistance, int pointSize) { double batchInertia = squareDistance / batchSize; if (ewaInertia == null) { ewaInertia = batchInertia; } else { + // Refer to sklearn, pointSize+1 maybe intent to avoid the div/0 error, + // but java double does not have a div/0 error double alpha = batchSize * 2.0 / (pointSize + 1); alpha = Math.min(alpha, 1.0); ewaInertia = ewaInertia * (1 - alpha) + batchInertia * alpha; } + + // Improved if (ewaInertia < ewaInertiaMin) { noImprovementTimes = 0; ewaInertiaMin = ewaInertia; } else { + // No improvement noImprovementTimes++; } + // Has no improvement continuous for many times return noImprovementTimes >= maxNoImprovementTimes; } } diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/RandomCentroidInitializer.java b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/RandomCentroidInitializer.java index 55b910d6c3..36d515e99d 100644 --- a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/RandomCentroidInitializer.java +++ b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/RandomCentroidInitializer.java @@ -1,11 +1,13 @@ package org.apache.commons.math4.ml.clustering.initialization; import org.apache.commons.math4.ml.clustering.CentroidCluster; -import org.apache.commons.math4.ml.clustering.ClusterUtils; import org.apache.commons.math4.ml.clustering.Clusterable; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.ListSampler; -import java.util.*; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; /** * Random choose the initial centers. @@ -31,7 +33,8 @@ public RandomCentroidInitializer(final UniformRandomProvider random) { */ @Override public List> chooseCentroids(Collection points, int k) { - ArrayList list = ClusterUtils.shuffle(points, random); + ArrayList list = new ArrayList(points); + ListSampler.shuffle(random, list); List> result = new ArrayList<>(k); for (int i = 0; i < k; i++) { result.add(new CentroidCluster<>(list.get(i))); diff --git a/src/test/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClustererTest.java b/src/test/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClustererTest.java index 2425ae86e5..42a9e9b2a3 100644 --- a/src/test/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClustererTest.java +++ b/src/test/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClustererTest.java @@ -1,12 +1,7 @@ package org.apache.commons.math4.ml.clustering; -import org.apache.commons.math4.linear.MatrixUtils; -import org.apache.commons.math4.linear.RealMatrix; import org.apache.commons.math4.ml.distance.DistanceMeasure; import org.apache.commons.math4.ml.distance.EuclideanDistance; -import org.apache.commons.math4.random.CorrelatedRandomVectorGenerator; -import org.apache.commons.math4.random.GaussianRandomGenerator; -import org.apache.commons.rng.UniformRandomProvider; import org.apache.commons.rng.simple.RandomSource; import org.junit.Assert; import org.junit.Test; @@ -18,79 +13,45 @@ public class MiniBatchKMeansClustererTest { private DistanceMeasure measure = new EuclideanDistance(); + /** + * Compare the result to KMeansPlusPlusClusterer + */ @Test - public void compareToKMeans() { + public void testCompareToKMeans() { //Generate 4 cluster - List data = generateCircles(); - KMeansPlusPlusClusterer kMeans = new KMeansPlusPlusClusterer<>(4); - MiniBatchKMeansClusterer miniBatchKMeans = new MiniBatchKMeansClusterer<>(4); - List> kMeansClusters = kMeans.cluster(data); - List> miniBatchKMeansClusters = miniBatchKMeans.cluster(data); - Assert.assertEquals(4, kMeansClusters.size()); - Assert.assertEquals(kMeansClusters.size(), miniBatchKMeansClusters.size()); - int totalDiffCount = 0; - double totalCenterDistance = 0.0; - for (CentroidCluster kMeanCluster : kMeansClusters) { - CentroidCluster miniBatchCluster = predict(kMeanCluster.getCenter(), miniBatchKMeansClusters); - totalDiffCount += Math.abs(kMeanCluster.getPoints().size() - miniBatchCluster.getPoints().size()); - totalCenterDistance += measure.compute(kMeanCluster.getCenter().getPoint(), miniBatchCluster.getCenter().getPoint()); - } - double diffRatio = totalDiffCount * 1.0 / data.size(); - System.out.println(String.format("Centers total distance: %f, clusters total diff points: %d, diff ratio: %f%%", - totalCenterDistance, totalDiffCount, diffRatio * 100)); - // Difference ratio less than 2% - Assert.assertTrue(String.format("Different points ratio %f%%!", diffRatio * 100), diffRatio < 0.02); - } - - private CentroidCluster predict(Clusterable point, List> clusters) { - double minDistance = Double.POSITIVE_INFINITY; - CentroidCluster nearestCluster = null; - for (CentroidCluster cluster : clusters) { - double distance = measure.compute(point.getPoint(), cluster.getCenter().getPoint()); - if (distance < minDistance) { - minDistance = distance; - nearestCluster = cluster; + int randomSeed = 0; + List data = generateCircles(randomSeed); + KMeansPlusPlusClusterer kMeans = new KMeansPlusPlusClusterer<>(4, -1, measure, + RandomSource.create(RandomSource.MT_64, randomSeed)); + MiniBatchKMeansClusterer miniBatchKMeans = new MiniBatchKMeansClusterer<>(4, -1, + measure, RandomSource.create(RandomSource.MT_64, randomSeed)); + for (int i = 0; i < 100; i++) { + List> kMeansClusters = kMeans.cluster(data); + List> miniBatchKMeansClusters = miniBatchKMeans.cluster(data); + Assert.assertEquals(4, kMeansClusters.size()); + Assert.assertEquals(kMeansClusters.size(), miniBatchKMeansClusters.size()); + int totalDiffCount = 0; + double totalCenterDistance = 0.0; + for (CentroidCluster kMeanCluster : kMeansClusters) { + CentroidCluster miniBatchCluster = ClusterUtils.predict(miniBatchKMeansClusters, kMeanCluster.getCenter()); + totalDiffCount += Math.abs(kMeanCluster.getPoints().size() - miniBatchCluster.getPoints().size()); + totalCenterDistance += measure.compute(kMeanCluster.getCenter().getPoint(), miniBatchCluster.getCenter().getPoint()); } + double diffRatio = totalDiffCount * 1.0 / data.size(); + System.out.println(String.format("Centers total distance: %f, clusters total diff points: %d, diff ratio: %f%%", + totalCenterDistance, totalDiffCount, diffRatio * 100)); + // Sometimes the +// Assert.assertTrue(String.format("Different points ratio %f%%!", diffRatio * 100), diffRatio < 0.03); } - return nearestCluster; - } - - private List generateClusters() { - List data = new ArrayList<>(); - data.addAll(generateCluster(250, new double[]{-1.0, -1.0}, 0.5)); - data.addAll(generateCluster(250, new double[]{0.0, 0.0}, 0.5)); - data.addAll(generateCluster(250, new double[]{1.0, 1.0}, 0.5)); - data.addAll(generateCluster(250, new double[]{2.0, 2.0}, 0.5)); - return data; - } - - private List generateCluster(int size, double[] center, double radius) { - UniformRandomProvider rg = RandomSource.create(RandomSource.MT_64, 0); - GaussianRandomGenerator rawGenerator = new GaussianRandomGenerator(rg); - double[] standardDeviation = {0.5, 0.5}; - double c = standardDeviation[0] * standardDeviation[1] * radius; - double[][] cov = {{standardDeviation[0] * standardDeviation[0], c}, {c, standardDeviation[1] * standardDeviation[1]}}; - RealMatrix covariance = MatrixUtils.createRealMatrix(cov); - // Create a CorrelatedRandomVectorGenerator using "rawGenerator" for the components. - CorrelatedRandomVectorGenerator generator = - new CorrelatedRandomVectorGenerator(center, covariance, 1.0e-12 * covariance.getNorm(), rawGenerator); - // Use the generator to generate correlated vectors. - List data = new ArrayList(size); - for (int i = 0; i < size; i++) { - // Use the generator to generate vectors - double[] randomVector = generator.nextVector(); - data.add(new DoublePoint(randomVector)); - } - return data; } - private List generateCircles() { + private List generateCircles(int randomSeed) { List data = new ArrayList<>(); - Random random = new Random(0); + Random random = new Random(randomSeed); data.addAll(generateCircle(250, new double[]{-1.0, -1.0}, 1.0, random)); - data.addAll(generateCircle(250, new double[]{0.0, 0.0}, 0.7, random)); - data.addAll(generateCircle(250, new double[]{1.0, 1.0}, 0.7, random)); - data.addAll(generateCircle(250, new double[]{2.0, 2.0}, 0.7, random)); + data.addAll(generateCircle(260, new double[]{0.0, 0.0}, 0.7, random)); + data.addAll(generateCircle(270, new double[]{1.0, 1.0}, 0.7, random)); + data.addAll(generateCircle(280, new double[]{2.0, 2.0}, 0.7, random)); return data; } From 47a055e6264d084547854f9290461f020f2131cf Mon Sep 17 00:00:00 2001 From: CT Date: Fri, 21 Feb 2020 11:41:38 +0800 Subject: [PATCH 3/4] Remove duplicate code in MiniBatchKMeansClusterer and KMeansPlusPlusClusterer --- .../clustering/KMeansPlusPlusClusterer.java | 275 ++++++------------ .../clustering/MiniBatchKMeansClusterer.java | 23 +- .../evaluation/ClusterEvaluator.java | 17 +- .../initialization/CentroidInitializer.java | 2 +- .../KMeansPlusPlusCentroidInitializer.java | 2 +- .../RandomCentroidInitializer.java | 2 +- 6 files changed, 91 insertions(+), 230 deletions(-) diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/KMeansPlusPlusClusterer.java b/src/main/java/org/apache/commons/math4/ml/clustering/KMeansPlusPlusClusterer.java index 74699ffb07..48208b432c 100644 --- a/src/main/java/org/apache/commons/math4/ml/clustering/KMeansPlusPlusClusterer.java +++ b/src/main/java/org/apache/commons/math4/ml/clustering/KMeansPlusPlusClusterer.java @@ -19,13 +19,14 @@ import java.util.ArrayList; import java.util.Collection; -import java.util.Collections; import java.util.List; import org.apache.commons.math4.exception.ConvergenceException; import org.apache.commons.math4.exception.MathIllegalArgumentException; import org.apache.commons.math4.exception.NumberIsTooSmallException; import org.apache.commons.math4.exception.util.LocalizedFormats; +import org.apache.commons.math4.ml.clustering.initialization.CentroidInitializer; +import org.apache.commons.math4.ml.clustering.initialization.KMeansPlusPlusCentroidInitializer; import org.apache.commons.math4.ml.distance.DistanceMeasure; import org.apache.commons.math4.ml.distance.EuclideanDistance; import org.apache.commons.rng.simple.RandomSource; @@ -35,42 +36,67 @@ /** * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm. + * * @param type of the points to cluster * @see K-means++ (wikipedia) * @since 3.2 */ public class KMeansPlusPlusClusterer extends Clusterer { - /** Strategies to use for replacing an empty cluster. */ + /** + * Strategies to use for replacing an empty cluster. + */ public enum EmptyClusterStrategy { - /** Split the cluster with largest distance variance. */ + /** + * Split the cluster with largest distance variance. + */ LARGEST_VARIANCE, - /** Split the cluster with largest number of points. */ + /** + * Split the cluster with largest number of points. + */ LARGEST_POINTS_NUMBER, - /** Create a cluster around the point farthest from its centroid. */ + /** + * Create a cluster around the point farthest from its centroid. + */ FARTHEST_POINT, - /** Generate an error. */ + /** + * Generate an error. + */ ERROR } - /** The number of clusters. */ + /** + * The number of clusters. + */ private final int k; - /** The maximum number of iterations. */ + /** + * The maximum number of iterations. + */ private final int maxIterations; - /** Random generator for choosing initial centers. */ + /** + * Random generator for choosing initial centers. + */ private final UniformRandomProvider random; - /** Selected strategy for empty clusters. */ + /** + * Selected strategy for empty clusters. + */ private final EmptyClusterStrategy emptyStrategy; - /** Build a clusterer. + /** + * Centroid initial algorithm + */ + private final CentroidInitializer centroidInitializer; + + /** + * Build a clusterer. *

* The default strategy for handling empty clusters that may appear during * algorithm iterations is to split the cluster with largest distance variance. @@ -83,45 +109,48 @@ public KMeansPlusPlusClusterer(final int k) { this(k, -1); } - /** Build a clusterer. + /** + * Build a clusterer. *

* The default strategy for handling empty clusters that may appear during * algorithm iterations is to split the cluster with largest distance variance. *

* The euclidean distance will be used as default distance measure. * - * @param k the number of clusters to split the data into + * @param k the number of clusters to split the data into * @param maxIterations the maximum number of iterations to run the algorithm for. - * If negative, no maximum will be used. + * If negative, no maximum will be used. */ public KMeansPlusPlusClusterer(final int k, final int maxIterations) { this(k, maxIterations, new EuclideanDistance()); } - /** Build a clusterer. + /** + * Build a clusterer. *

* The default strategy for handling empty clusters that may appear during * algorithm iterations is to split the cluster with largest distance variance. * - * @param k the number of clusters to split the data into + * @param k the number of clusters to split the data into * @param maxIterations the maximum number of iterations to run the algorithm for. - * If negative, no maximum will be used. - * @param measure the distance measure to use + * If negative, no maximum will be used. + * @param measure the distance measure to use */ public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) { this(k, maxIterations, measure, RandomSource.create(RandomSource.MT_64)); } - /** Build a clusterer. + /** + * Build a clusterer. *

* The default strategy for handling empty clusters that may appear during * algorithm iterations is to split the cluster with largest distance variance. * - * @param k the number of clusters to split the data into + * @param k the number of clusters to split the data into * @param maxIterations the maximum number of iterations to run the algorithm for. - * If negative, no maximum will be used. - * @param measure the distance measure to use - * @param random random generator to use for choosing initial centers + * If negative, no maximum will be used. + * @param measure the distance measure to use + * @param random random generator to use for choosing initial centers */ public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure, @@ -129,29 +158,33 @@ public KMeansPlusPlusClusterer(final int k, final int maxIterations, this(k, maxIterations, measure, random, EmptyClusterStrategy.LARGEST_VARIANCE); } - /** Build a clusterer. + /** + * Build a clusterer. * - * @param k the number of clusters to split the data into + * @param k the number of clusters to split the data into * @param maxIterations the maximum number of iterations to run the algorithm for. - * If negative, no maximum will be used. - * @param measure the distance measure to use - * @param random random generator to use for choosing initial centers + * If negative, no maximum will be used. + * @param measure the distance measure to use + * @param random random generator to use for choosing initial centers * @param emptyStrategy strategy to use for handling empty clusters that - * may appear during algorithm iterations + * may appear during algorithm iterations */ public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure, final UniformRandomProvider random, final EmptyClusterStrategy emptyStrategy) { super(measure); - this.k = k; + this.k = k; this.maxIterations = maxIterations; - this.random = random; + this.random = random; this.emptyStrategy = emptyStrategy; + // It is a Common KMeans algorithm if centroidInitializer is not KMeansPlusPlus algorithm. + this.centroidInitializer = new KMeansPlusPlusCentroidInitializer(measure, random); } /** * Return the number of clusters this instance will use. + * * @return the number of clusters */ public int getK() { @@ -160,6 +193,7 @@ public int getK() { /** * Returns the maximum number of iterations this instance will use. + * * @return the maximum number of iterations, or -1 if no maximum is set */ public int getMaxIterations() { @@ -168,6 +202,7 @@ public int getMaxIterations() { /** * Returns the random generator this instance will use. + * * @return the random generator */ public UniformRandomProvider getRandomGenerator() { @@ -176,6 +211,7 @@ public UniformRandomProvider getRandomGenerator() { /** * Returns the {@link EmptyClusterStrategy} used by this instance. + * * @return the {@link EmptyClusterStrategy} */ public EmptyClusterStrategy getEmptyClusterStrategy() { @@ -188,13 +224,13 @@ public EmptyClusterStrategy getEmptyClusterStrategy() { * @param points the points to cluster * @return a list of clusters containing the points * @throws MathIllegalArgumentException if the data points are null or the number - * of clusters is larger than the number of data points - * @throws ConvergenceException if an empty cluster is encountered and the - * {@link #emptyStrategy} is set to {@code ERROR} + * of clusters is larger than the number of data points + * @throws ConvergenceException if an empty cluster is encountered and the + * {@link #emptyStrategy} is set to {@code ERROR} */ @Override public List> cluster(final Collection points) - throws MathIllegalArgumentException, ConvergenceException { + throws MathIllegalArgumentException, ConvergenceException { // sanity checks MathUtils.checkNotNull(points); @@ -205,7 +241,7 @@ public List> cluster(final Collection points) } // create the initial clusters - List> clusters = chooseInitialCenters(points); + List> clusters = centroidInitializer.selectCentroids(points, k); // create an array containing the latest assignment of a point to a cluster // no need to initialize the array, as it will be filled with the first assignment @@ -221,21 +257,21 @@ public List> cluster(final Collection points) final Clusterable newCenter; if (cluster.getPoints().isEmpty()) { switch (emptyStrategy) { - case LARGEST_VARIANCE : + case LARGEST_VARIANCE: newCenter = getPointFromLargestVarianceCluster(clusters); break; - case LARGEST_POINTS_NUMBER : + case LARGEST_POINTS_NUMBER: newCenter = getPointFromLargestNumberCluster(clusters); break; - case FARTHEST_POINT : + case FARTHEST_POINT: newCenter = getFarthestPoint(clusters); break; - default : + default: throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); } emptyCluster = true; } else { - newCenter = centroidOf(cluster.getPoints(), cluster.getCenter().getPoint().length); + newCenter = ClusterUtils.centroidOf(cluster.getPoints(), cluster.getCenter().getPoint().length); } newClusters.add(new CentroidCluster(newCenter)); } @@ -254,8 +290,8 @@ public List> cluster(final Collection points) /** * Adds the given points to the closest {@link Cluster}. * - * @param clusters the {@link Cluster}s to add the points to - * @param points the points to add to the given {@link Cluster}s + * @param clusters the {@link Cluster}s to add the points to + * @param points the points to add to the given {@link Cluster}s * @param assignments points assignments to clusters * @return the number of points assigned to different clusters as the iteration before */ @@ -278,131 +314,6 @@ private int assignPointsToClusters(final List> clusters, return assignedDifferently; } - /** - * Use K-means++ to choose the initial centers. - * - * @param points the points to choose the initial centers from - * @return the initial centers - */ - private List> chooseInitialCenters(final Collection points) { - - // Convert to list for indexed access. Make it unmodifiable, since removal of items - // would screw up the logic of this method. - final List pointList = Collections.unmodifiableList(new ArrayList<> (points)); - - // The number of points in the list. - final int numPoints = pointList.size(); - - // Set the corresponding element in this array to indicate when - // elements of pointList are no longer available. - final boolean[] taken = new boolean[numPoints]; - - // The resulting list of initial centers. - final List> resultSet = new ArrayList<>(); - - // Choose one center uniformly at random from among the data points. - final int firstPointIndex = random.nextInt(numPoints); - - final T firstPoint = pointList.get(firstPointIndex); - - resultSet.add(new CentroidCluster(firstPoint)); - - // Must mark it as taken - taken[firstPointIndex] = true; - - // To keep track of the minimum distance squared of elements of - // pointList to elements of resultSet. - final double[] minDistSquared = new double[numPoints]; - - // Initialize the elements. Since the only point in resultSet is firstPoint, - // this is very easy. - for (int i = 0; i < numPoints; i++) { - if (i != firstPointIndex) { // That point isn't considered - double d = distance(firstPoint, pointList.get(i)); - minDistSquared[i] = d*d; - } - } - - while (resultSet.size() < k) { - - // Sum up the squared distances for the points in pointList not - // already taken. - double distSqSum = 0.0; - - for (int i = 0; i < numPoints; i++) { - if (!taken[i]) { - distSqSum += minDistSquared[i]; - } - } - - // Add one new data point as a center. Each point x is chosen with - // probability proportional to D(x)2 - final double r = random.nextDouble() * distSqSum; - - // The index of the next point to be added to the resultSet. - int nextPointIndex = -1; - - // Sum through the squared min distances again, stopping when - // sum >= r. - double sum = 0.0; - for (int i = 0; i < numPoints; i++) { - if (!taken[i]) { - sum += minDistSquared[i]; - if (sum >= r) { - nextPointIndex = i; - break; - } - } - } - - // If it's not set to >= 0, the point wasn't found in the previous - // for loop, probably because distances are extremely small. Just pick - // the last available point. - if (nextPointIndex == -1) { - for (int i = numPoints - 1; i >= 0; i--) { - if (!taken[i]) { - nextPointIndex = i; - break; - } - } - } - - // We found one. - if (nextPointIndex >= 0) { - - final T p = pointList.get(nextPointIndex); - - resultSet.add(new CentroidCluster (p)); - - // Mark it as taken. - taken[nextPointIndex] = true; - - if (resultSet.size() < k) { - // Now update elements of minDistSquared. We only have to compute - // the distance to the new center to do this. - for (int j = 0; j < numPoints; j++) { - // Only have to worry about the points still not taken. - if (!taken[j]) { - double d = distance(p, pointList.get(j)); - double d2 = d * d; - if (d2 < minDistSquared[j]) { - minDistSquared[j] = d2; - } - } - } - } - - } else { - // None found -- - // Break from the while loop to prevent - // an infinite loop. - break; - } - } - - return resultSet; - } - /** * Get a random point from the {@link Cluster} with the largest distance variance. * @@ -502,9 +413,9 @@ private T getFarthestPoint(final Collection> clusters) throws for (int i = 0; i < points.size(); ++i) { final double distance = distance(points.get(i), center); if (distance > maxDistance) { - maxDistance = distance; + maxDistance = distance; selectedCluster = cluster; - selectedPoint = i; + selectedPoint = i; } } @@ -523,7 +434,7 @@ private T getFarthestPoint(final Collection> clusters) throws * Returns the nearest {@link Cluster} to the given point * * @param clusters the {@link Cluster}s to search - * @param point the point to find the nearest {@link Cluster} for + * @param point the point to find the nearest {@link Cluster} for * @return the index of the nearest {@link Cluster} to the given point */ private int getNearestCluster(final Collection> clusters, final T point) { @@ -540,26 +451,4 @@ private int getNearestCluster(final Collection> clusters, fin } return minCluster; } - - /** - * Computes the centroid for a set of points. - * - * @param points the set of points - * @param dimension the point dimension - * @return the computed centroid for the set of points - */ - private Clusterable centroidOf(final Collection points, final int dimension) { - final double[] centroid = new double[dimension]; - for (final T p : points) { - final double[] point = p.getPoint(); - for (int i = 0; i < centroid.length; i++) { - centroid[i] += point[i]; - } - } - for (int i = 0; i < centroid.length; i++) { - centroid[i] /= points.size(); - } - return new DoublePoint(centroid); - } - } diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClusterer.java b/src/main/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClusterer.java index 6a9fe97bd0..9c848d61c2 100644 --- a/src/main/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClusterer.java +++ b/src/main/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClusterer.java @@ -6,12 +6,10 @@ import org.apache.commons.math4.ml.clustering.initialization.CentroidInitializer; import org.apache.commons.math4.ml.clustering.initialization.KMeansPlusPlusCentroidInitializer; import org.apache.commons.math4.ml.distance.DistanceMeasure; -import org.apache.commons.math4.ml.distance.EuclideanDistance; import org.apache.commons.math4.util.MathUtils; import org.apache.commons.math4.util.Pair; import org.apache.commons.rng.UniformRandomProvider; import org.apache.commons.rng.sampling.ListSampler; -import org.apache.commons.rng.simple.RandomSource; import java.util.ArrayList; import java.util.Collection; @@ -94,7 +92,7 @@ public MiniBatchKMeansClusterer(final int k, int maxIterations, final int batchS } /** - * Build a clusterer. + * Build a clusterer * * @param k the number of clusters to split the data into * @param maxIterations the maximum number of iterations to run the algorithm for. @@ -108,16 +106,6 @@ public MiniBatchKMeansClusterer(int k, int maxIterations, DistanceMeasure measur measure, random, new KMeansPlusPlusCentroidInitializer(measure, random)); } - - /** - * Build a clusterer. - * - * @param k the number of clusters to split the data into - */ - public MiniBatchKMeansClusterer(int k) { - this(k, 100, new EuclideanDistance(), RandomSource.create(RandomSource.MT_64)); - } - /** * Runs the MiniBatch K-means clustering algorithm. * @@ -137,8 +125,9 @@ public List> cluster(Collection points) throws MathIllegal } int pointSize = points.size(); + int batchSize = this.batchSize; int batchCount = pointSize / batchSize + ((pointSize % batchSize > 0) ? 1 : 0); - int maxIterations = this.maxIterations * batchCount; + int maxIterations = (this.maxIterations <= 0) ? Integer.MAX_VALUE : (this.maxIterations * batchCount); MiniBatchImprovementEvaluator evaluator = new MiniBatchImprovementEvaluator(); List> clusters = initialCenters(points); for (int i = 0; i < maxIterations; i++) { @@ -216,10 +205,6 @@ private Pair>> step( private List randomMiniBatch(Collection points, int batchSize) { ArrayList list = new ArrayList(points); ListSampler.shuffle(random, list); -// int size = list.size(); -// for (int i = size; i > 1; --i) { -// list.set(i - 1, list.set(random.nextInt(i), list.get(i - 1))); -// } return list.subList(0, batchSize); } @@ -235,7 +220,7 @@ private List> initialCenters(Collection points) { List> bestCenters = null; for (int i = 0; i < initIterations; i++) { List initialPoints = (initBatchSize < points.size()) ? randomMiniBatch(points, initBatchSize) : new ArrayList(points); - List> clusters = centroidInitializer.chooseCentroids(initialPoints, k); + List> clusters = centroidInitializer.selectCentroids(initialPoints, k); Pair>> pair = step(validPoints, clusters); double squareDistance = pair.getFirst(); List> newClusters = pair.getSecond(); diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/evaluation/ClusterEvaluator.java b/src/main/java/org/apache/commons/math4/ml/clustering/evaluation/ClusterEvaluator.java index 5f364c040e..314b6f8aac 100644 --- a/src/main/java/org/apache/commons/math4/ml/clustering/evaluation/ClusterEvaluator.java +++ b/src/main/java/org/apache/commons/math4/ml/clustering/evaluation/ClusterEvaluator.java @@ -19,10 +19,7 @@ import java.util.List; -import org.apache.commons.math4.ml.clustering.CentroidCluster; -import org.apache.commons.math4.ml.clustering.Cluster; -import org.apache.commons.math4.ml.clustering.Clusterable; -import org.apache.commons.math4.ml.clustering.DoublePoint; +import org.apache.commons.math4.ml.clustering.*; import org.apache.commons.math4.ml.distance.DistanceMeasure; import org.apache.commons.math4.ml.distance.EuclideanDistance; @@ -106,17 +103,7 @@ protected Clusterable centroidOf(final Cluster cluster) { } final int dimension = points.get(0).getPoint().length; - final double[] centroid = new double[dimension]; - for (final T p : points) { - final double[] point = p.getPoint(); - for (int i = 0; i < centroid.length; i++) { - centroid[i] += point[i]; - } - } - for (int i = 0; i < centroid.length; i++) { - centroid[i] /= points.size(); - } - return new DoublePoint(centroid); + return ClusterUtils.centroidOf(points,dimension); } } diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializer.java b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializer.java index 9a188edde2..4adb67406c 100644 --- a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializer.java +++ b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializer.java @@ -18,5 +18,5 @@ public interface CentroidInitializer { * @param k The number of clusters * @return the initial centers */ - List> chooseCentroids(final Collection points, final int k); + List> selectCentroids(final Collection points, final int k); } diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/KMeansPlusPlusCentroidInitializer.java b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/KMeansPlusPlusCentroidInitializer.java index ef19179855..bc94987979 100644 --- a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/KMeansPlusPlusCentroidInitializer.java +++ b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/KMeansPlusPlusCentroidInitializer.java @@ -37,7 +37,7 @@ public KMeansPlusPlusCentroidInitializer(final DistanceMeasure measure, final Un * @return the initial centers */ @Override - public List> chooseCentroids(final Collection points, final int k) { + public List> selectCentroids(final Collection points, final int k) { // Convert to list for indexed access. Make it unmodifiable, since removal of items // would screw up the logic of this method. final List pointList = Collections.unmodifiableList(new ArrayList<>(points)); diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/RandomCentroidInitializer.java b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/RandomCentroidInitializer.java index 36d515e99d..723876711b 100644 --- a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/RandomCentroidInitializer.java +++ b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/RandomCentroidInitializer.java @@ -32,7 +32,7 @@ public RandomCentroidInitializer(final UniformRandomProvider random) { * @return the initial centers */ @Override - public List> chooseCentroids(Collection points, int k) { + public List> selectCentroids(Collection points, int k) { ArrayList list = new ArrayList(points); ListSampler.shuffle(random, list); List> result = new ArrayList<>(k); From 6e15a9ff28ba48adf1120ce0d36634d49ed0f9a7 Mon Sep 17 00:00:00 2001 From: CT Date: Fri, 21 Feb 2020 11:42:19 +0800 Subject: [PATCH 4/4] Compare MiniBatchKMeansClusterer to KMeansPlusPlusClusterer by SumOfClusterVariances. --- .../MiniBatchKMeansClustererTest.java | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/test/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClustererTest.java b/src/test/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClustererTest.java index 42a9e9b2a3..980a62cec0 100644 --- a/src/test/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClustererTest.java +++ b/src/test/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClustererTest.java @@ -1,5 +1,7 @@ package org.apache.commons.math4.ml.clustering; +import org.apache.commons.math4.ml.clustering.evaluation.ClusterEvaluator; +import org.apache.commons.math4.ml.clustering.evaluation.SumOfClusterVariances; import org.apache.commons.math4.ml.distance.DistanceMeasure; import org.apache.commons.math4.ml.distance.EuclideanDistance; import org.apache.commons.rng.simple.RandomSource; @@ -31,17 +33,19 @@ public void testCompareToKMeans() { Assert.assertEquals(4, kMeansClusters.size()); Assert.assertEquals(kMeansClusters.size(), miniBatchKMeansClusters.size()); int totalDiffCount = 0; - double totalCenterDistance = 0.0; for (CentroidCluster kMeanCluster : kMeansClusters) { CentroidCluster miniBatchCluster = ClusterUtils.predict(miniBatchKMeansClusters, kMeanCluster.getCenter()); totalDiffCount += Math.abs(kMeanCluster.getPoints().size() - miniBatchCluster.getPoints().size()); - totalCenterDistance += measure.compute(kMeanCluster.getCenter().getPoint(), miniBatchCluster.getCenter().getPoint()); } - double diffRatio = totalDiffCount * 1.0 / data.size(); - System.out.println(String.format("Centers total distance: %f, clusters total diff points: %d, diff ratio: %f%%", - totalCenterDistance, totalDiffCount, diffRatio * 100)); - // Sometimes the -// Assert.assertTrue(String.format("Different points ratio %f%%!", diffRatio * 100), diffRatio < 0.03); + ClusterEvaluator clusterEvaluator = new SumOfClusterVariances<>(measure); + double kMeansScore = clusterEvaluator.score(kMeansClusters); + double miniBatchKMeansScore = clusterEvaluator.score(miniBatchKMeansClusters); + double diffPointsRatio = totalDiffCount * 1.0 / data.size(); + double scoreDiffRatio = (miniBatchKMeansScore - kMeansScore) / + kMeansScore; + // MiniBatchKMeansClusterer has few score differences between KMeansClusterer + Assert.assertTrue(String.format("Different score ratio %f%%!, diff points ratio: %f%%\"", scoreDiffRatio * 100, diffPointsRatio * 100), + scoreDiffRatio < 0.1); } }