diff --git a/src/analysis/processing/qgsalgorithmkmeansclustering.cpp b/src/analysis/processing/qgsalgorithmkmeansclustering.cpp index 134652a0434e..352671ecaf91 100644 --- a/src/analysis/processing/qgsalgorithmkmeansclustering.cpp +++ b/src/analysis/processing/qgsalgorithmkmeansclustering.cpp @@ -17,6 +17,7 @@ #include "qgsalgorithmkmeansclustering.h" #include +#include ///@cond PRIVATE @@ -52,6 +53,11 @@ void QgsKMeansClusteringAlgorithm::initAlgorithm( const QVariantMap & ) addParameter( new QgsProcessingParameterFeatureSource( QStringLiteral( "INPUT" ), QObject::tr( "Input layer" ), QList() << static_cast( Qgis::ProcessingSourceType::VectorAnyGeometry ) ) ); addParameter( new QgsProcessingParameterNumber( QStringLiteral( "CLUSTERS" ), QObject::tr( "Number of clusters" ), Qgis::ProcessingNumberParameterType::Integer, 5, false, 1 ) ); + QStringList initializationMethods; + initializationMethods << QObject::tr( "Farthest points" ) + << QObject::tr( "K-means++" ); + addParameter( new QgsProcessingParameterEnum( QStringLiteral( "METHOD" ), QObject::tr( "Method" ), initializationMethods, false, 0, false ) ); + auto fieldNameParam = std::make_unique( QStringLiteral( "FIELD_NAME" ), QObject::tr( "Cluster field name" ), QStringLiteral( "CLUSTER_ID" ) ); fieldNameParam->setFlags( fieldNameParam->flags() | Qgis::ProcessingParameterFlag::Advanced ); addParameter( fieldNameParam.release() ); @@ -65,7 +71,10 @@ void QgsKMeansClusteringAlgorithm::initAlgorithm( const QVariantMap & ) QString QgsKMeansClusteringAlgorithm::shortHelpString() const { return QObject::tr( "Calculates the 2D distance based k-means cluster number for each input feature.\n\n" - "If input geometries are lines or polygons, the clustering is based on the centroid of the feature." ); + "If input geometries are lines or polygons, the clustering is based on the centroid of the feature.\\n" + "References:\\n" + "Arthur, David & Vassilvitskii, Sergei. (2007). K-Means++: The Advantages of Careful Seeding. Proc. of the Annu. ACM-SIAM Symp. on Discrete Algorithms. 8.\\n" + "Bhattacharya, Anup & Eube, Jan & Röglin, Heiko & Schmidt, Melanie. (2019). Noisy, Greedy and Not So Greedy k-means++"); } QgsKMeansClusteringAlgorithm *QgsKMeansClusteringAlgorithm::createInstance() const @@ -80,6 +89,7 @@ QVariantMap QgsKMeansClusteringAlgorithm::processAlgorithm( const QVariantMap &p throw QgsProcessingException( invalidSourceError( parameters, QStringLiteral( "INPUT" ) ) ); int k = parameterAsInt( parameters, QStringLiteral( "CLUSTERS" ), context ); + int initializationMethod = parameterAsInt( parameters, QStringLiteral( "METHOD" ), context ); QgsFields outputFields = source->fields(); QgsFields newFields; @@ -148,8 +158,17 @@ QVariantMap QgsKMeansClusteringAlgorithm::processAlgorithm( const QVariantMap &p // cluster centers std::vector centers( k ); - - initClusters( clusterFeatures, centers, k, feedback ); + switch ( initializationMethod ) + { + case 0: // farthest points + initClustersFarthestPoints( clusterFeatures, centers, k, feedback ); + break; + case 1: // k-means++ + initClustersPlusPlus( clusterFeatures, centers, k, feedback ); + break; + default: + break; + } calculateKMeans( clusterFeatures, centers, k, feedback ); } @@ -198,7 +217,7 @@ QVariantMap QgsKMeansClusteringAlgorithm::processAlgorithm( const QVariantMap &p // ported from https://github.com/postgis/postgis/blob/svn-trunk/liblwgeom/lwkmeans.c -void QgsKMeansClusteringAlgorithm::initClusters( std::vector &points, std::vector ¢ers, const int k, QgsProcessingFeedback *feedback ) +void QgsKMeansClusteringAlgorithm::initClustersFarthestPoints( std::vector &points, std::vector ¢ers, const int k, QgsProcessingFeedback *feedback ) { const std::size_t n = points.size(); if ( n == 0 ) @@ -298,6 +317,140 @@ void QgsKMeansClusteringAlgorithm::initClusters( std::vector &points, s } } +void QgsKMeansClusteringAlgorithm::initClustersPlusPlus( std::vector &points, std::vector ¢ers, const int k, QgsProcessingFeedback *feedback ) +{ + const std::size_t n = points.size(); + if ( n == 0 ) + return; + + if ( n == 1 ) + { + for ( int i = 0; i < k; i++ ) + centers[i] = points[0].point; + return; + } + + // randomly select the first point + std::random_device rd; + std::mt19937 gen( rd() ); + std::uniform_int_distribution distrib( 0, n - 1 ); + + int p1 = distrib( gen ); + centers[0] = points[p1].point; + + // calculate distances and total error (sum of distances of points to center) + std::vector distances( n ); + double totalError = 0; + long duplicateCount = 1; + for ( size_t i = 0; i < n; i++ ) + { + double distance = points[i].point.sqrDist( centers[0] ); + distances[i] = distance; + totalError += distance; + if ( qgsDoubleNear( distance, 0 ) ) + { + duplicateCount++; + } + } + if ( feedback && duplicateCount > 1 ) + { + feedback->pushInfo( QObject::tr( "There are at least %n duplicate input(s), the number of output clusters may be less than was requested", nullptr, duplicateCount ) ); + } + + // greedy kmeans++ + // test not only one center but L possible centers + // chosen independently according to the same probability distribution), and then among these L + // centers, the one that decreases the k-means cost the most is chosen + // Bhattacharya, Anup & Eube, Jan & Röglin, Heiko & Schmidt, Melanie. (2019). Noisy, greedy and Not So greedy k-means++ + unsigned int numCandidateCenters = 2 + std::floor( std::log( k ) ); + std::vector randomNumbers( numCandidateCenters ); + std::vector candidateCenters( numCandidateCenters ); + + std::uniform_real_distribution dis( 0.0, 1.0 ); + for ( int i = 1; i < k; i++ ) + { + // sampling with probability proportional to the squared distance to the closest existing center + for ( unsigned int j = 0; j < numCandidateCenters; j++ ) + { + randomNumbers[j] = dis( gen ) * totalError; + } + + // cumulative sum, keep distances for later use + std::vector cumSum = distances; + for ( size_t j = 1; j < n; j++ ) + { + cumSum[j] += cumSum[j - 1]; + } + + // binary search for the index of the first element greater than or equal to random numbers + for ( unsigned int j = 0; j < numCandidateCenters; j++ ) + { + size_t low = 0; + size_t high = n - 1; + + while ( low <= high ) + { + size_t mid = low + ( high - low ) / 2; + if ( cumSum[mid] < randomNumbers[j] ) + { + low = mid + 1; + } + else + { + // size_t cannot be negative + if ( mid == 0 ) + break; + + high = mid - 1; + } + } + // clip candidate center to the number of points + if ( low >= n ) + { + low = n - 1; + } + candidateCenters[j] = low; + } + + std::vector> distancesCandidateCenters( numCandidateCenters, std::vector( n ) );; + + // store distances between previous and new candidate center, error and best candidate index + double currentError = 0; + double lowestError = std::numeric_limits::max(); + unsigned int bestCandidateIndex = 0; + for ( unsigned int j = 0; j < numCandidateCenters; j++ ) + { + for ( size_t z = 0; z < n; z++ ) + { + // distance to candidate center + double distance = points[candidateCenters[j]].point.sqrDist( points[z].point ); + // if distance to previous center is less than the current distance, use that + if ( distance > distances[z] ) + { + distance = distances[z]; + } + distancesCandidateCenters[j][z] = distance; + currentError += distance; + } + if ( lowestError > currentError ) + { + lowestError = currentError; + bestCandidateIndex = j; + } + } + + // update distances with the best candidate center values + for ( size_t j = 0; j < n; j++ ) + { + distances[j] = distancesCandidateCenters[bestCandidateIndex][j]; + } + // store the best candidate center + centers[i] = points[candidateCenters[bestCandidateIndex]].point; + // update error + totalError = lowestError; + } +} + // ported from https://github.com/postgis/postgis/blob/svn-trunk/liblwgeom/lwkmeans.c void QgsKMeansClusteringAlgorithm::calculateKMeans( std::vector &objs, std::vector ¢ers, int k, QgsProcessingFeedback *feedback ) diff --git a/src/analysis/processing/qgsalgorithmkmeansclustering.h b/src/analysis/processing/qgsalgorithmkmeansclustering.h index b5c43602e7a8..b1ff8d56a935 100644 --- a/src/analysis/processing/qgsalgorithmkmeansclustering.h +++ b/src/analysis/processing/qgsalgorithmkmeansclustering.h @@ -57,7 +57,8 @@ class ANALYSIS_EXPORT QgsKMeansClusteringAlgorithm : public QgsProcessingAlgorit int cluster = -1; }; - static void initClusters( std::vector &points, std::vector ¢ers, int k, QgsProcessingFeedback *feedback ); + static void initClustersFarthestPoints( std::vector &points, std::vector ¢ers, int k, QgsProcessingFeedback *feedback ); + static void initClustersPlusPlus( std::vector &points, std::vector ¢ers, int k, QgsProcessingFeedback *feedback ); static void calculateKMeans( std::vector &points, std::vector ¢ers, int k, QgsProcessingFeedback *feedback ); static void findNearest( std::vector &points, const std::vector ¢ers, int k, bool &changed ); static void updateMeans( const std::vector &points, std::vector ¢ers, std::vector &weights, int k ); diff --git a/tests/src/analysis/testqgsprocessingalgspt1.cpp b/tests/src/analysis/testqgsprocessingalgspt1.cpp index 64ae2a8539af..97af3e73e16c 100644 --- a/tests/src/analysis/testqgsprocessingalgspt1.cpp +++ b/tests/src/analysis/testqgsprocessingalgspt1.cpp @@ -1016,44 +1016,65 @@ void TestQgsProcessingAlgsPt1::kmeansCluster() // no features, no crash int k = 2; - QgsKMeansClusteringAlgorithm::initClusters( features, centers, k, nullptr ); + // farthest points + QgsKMeansClusteringAlgorithm::initClustersFarthestPoints( features, centers, k, nullptr ); + QgsKMeansClusteringAlgorithm::calculateKMeans( features, centers, k, nullptr ); + // kmeans++ + QgsKMeansClusteringAlgorithm::initClustersPlusPlus( features, centers, k, nullptr ); QgsKMeansClusteringAlgorithm::calculateKMeans( features, centers, k, nullptr ); // features < clusters - features.emplace_back( QgsKMeansClusteringAlgorithm::Feature( QgsPointXY( 1, 5 ) ) ); - QgsKMeansClusteringAlgorithm::initClusters( features, centers, k, nullptr ); + // farthest points + features.emplace_back( QgsKMeansClusteringAlgorithm::Feature( QgsPointXY( 1, 1 ) ) ); + QgsKMeansClusteringAlgorithm::initClustersFarthestPoints( features, centers, k, nullptr ); + QgsKMeansClusteringAlgorithm::calculateKMeans( features, centers, k, nullptr ); + QCOMPARE( features[0].cluster, 0 ); + // kmeans++ + QgsKMeansClusteringAlgorithm::initClustersPlusPlus( features, centers, k, nullptr ); QgsKMeansClusteringAlgorithm::calculateKMeans( features, centers, k, nullptr ); QCOMPARE( features[0].cluster, 0 ); // features == clusters - features.emplace_back( QgsKMeansClusteringAlgorithm::Feature( QgsPointXY( 11, 5 ) ) ); - QgsKMeansClusteringAlgorithm::initClusters( features, centers, k, nullptr ); + features.emplace_back( QgsKMeansClusteringAlgorithm::Feature( QgsPointXY( 3, 1 ) ) ); + // farthest points + QgsKMeansClusteringAlgorithm::initClustersFarthestPoints( features, centers, k, nullptr ); QgsKMeansClusteringAlgorithm::calculateKMeans( features, centers, k, nullptr ); QCOMPARE( features[0].cluster, 1 ); QCOMPARE( features[1].cluster, 0 ); + // kmeans++ + QgsKMeansClusteringAlgorithm::initClustersPlusPlus( features, centers, k, nullptr ); + QgsKMeansClusteringAlgorithm::calculateKMeans( features, centers, k, nullptr ); + QVERIFY( features[0].cluster != features[1].cluster ); // features > clusters - features.emplace_back( QgsKMeansClusteringAlgorithm::Feature( QgsPointXY( 13, 3 ) ) ); - features.emplace_back( QgsKMeansClusteringAlgorithm::Feature( QgsPointXY( 13, 13 ) ) ); - features.emplace_back( QgsKMeansClusteringAlgorithm::Feature( QgsPointXY( 23, 6 ) ) ); + features.emplace_back( QgsKMeansClusteringAlgorithm::Feature( QgsPointXY( 2, 8 ) ) ); + features.emplace_back( QgsKMeansClusteringAlgorithm::Feature( QgsPointXY( 1, 10 ) ) ); + features.emplace_back( QgsKMeansClusteringAlgorithm::Feature( QgsPointXY( 3, 10 ) ) ); k = 2; - QgsKMeansClusteringAlgorithm::initClusters( features, centers, k, nullptr ); + // farthest points + QgsKMeansClusteringAlgorithm::initClustersFarthestPoints( features, centers, k, nullptr ); QgsKMeansClusteringAlgorithm::calculateKMeans( features, centers, k, nullptr ); QCOMPARE( features[0].cluster, 1 ); QCOMPARE( features[1].cluster, 1 ); QCOMPARE( features[2].cluster, 0 ); QCOMPARE( features[3].cluster, 0 ); QCOMPARE( features[4].cluster, 0 ); + // kmeans++ + QgsKMeansClusteringAlgorithm::initClustersPlusPlus( features, centers, k, nullptr ); + QgsKMeansClusteringAlgorithm::calculateKMeans( features, centers, k, nullptr ); + QCOMPARE( features[0].cluster, features[1].cluster ); + QCOMPARE( features[2].cluster, features[3].cluster ); + QCOMPARE( features[4].cluster, features[3].cluster ); // repeat above, with 3 clusters k = 3; centers.resize( 3 ); - QgsKMeansClusteringAlgorithm::initClusters( features, centers, k, nullptr ); + QgsKMeansClusteringAlgorithm::initClustersFarthestPoints( features, centers, k, nullptr ); QgsKMeansClusteringAlgorithm::calculateKMeans( features, centers, k, nullptr ); QCOMPARE( features[0].cluster, 1 ); - QCOMPARE( features[1].cluster, 2 ); + QCOMPARE( features[1].cluster, 1 ); QCOMPARE( features[2].cluster, 2 ); - QCOMPARE( features[3].cluster, 2 ); + QCOMPARE( features[3].cluster, 0 ); QCOMPARE( features[4].cluster, 0 ); // with identical points