1- // Copyright (c) Microsoft Corporation. All rights reserved.
2- // Licensed under the MIT License.
1+ // Copyright (c) Microsoft Corporation. All rights reserved.
2+ // Licensed under the MIT License.
33
44#include < mpi.h>
55#include < thread>
@@ -156,6 +156,8 @@ inline float MultipleClustersAssign(const COMMON::Dataset<T>& data,
156156 for (int k = 0 ; k < args._K ; k++) avgCount += args.counts [k];
157157 avgCount /= args._K ;
158158
159+ std::vector<float > dist_total (args._K * args._T , 0 );
160+
159161 auto func = [&](int tid)
160162 {
161163 SizeType istart = first + tid * subsize;
@@ -165,6 +167,7 @@ inline float MultipleClustersAssign(const COMMON::Dataset<T>& data,
165167 float * inewCenters = args.newCenters + tid * args._K * args._D ;
166168 SizeType* iclusterIdx = args.clusterIdx + tid * args._K ;
167169 float * iclusterDist = args.clusterDist + tid * args._K ;
170+ float * idist_total = dist_total.data () + tid * args._K ;
168171 float idist = 0 ;
169172 std::vector<SPTAG::NodeDistPair> centerDist (args._K , SPTAG::NodeDistPair ());
170173 for (SizeType i = istart; i < iend; i++) {
@@ -184,6 +187,7 @@ inline float MultipleClustersAssign(const COMMON::Dataset<T>& data,
184187 inewCounts[centerDist[k].node ]++;
185188 inewWeightedCounts[centerDist[k].node ] += weights[indices[i]];
186189 idist += centerDist[k].distance ;
190+ idist_total[centerDist[k].node ] += centerDist[k].distance ;
187191
188192 if (updateCenters) {
189193 const T* v = (const T*)data[indices[i]];
@@ -217,9 +221,15 @@ inline float MultipleClustersAssign(const COMMON::Dataset<T>& data,
217221 for (int k = 0 ; k < args._K ; k++) {
218222 args.newCounts [k] += args.newCounts [i*args._K + k];
219223 args.newWeightedCounts [k] += args.newWeightedCounts [i*args._K + k];
224+ dist_total[k] += dist_total[i * args._K + k];
220225 }
221226 }
222227
228+ LOG (Helper::LogLevel::LL_Info, " start printing dist_total\n " );
229+ for (int k = 0 ; k < args._K ; k++) {
230+ LOG (Helper::LogLevel::LL_Info, " cluster %d: dist_total:%f \n " , k, dist_total[k]);
231+ }
232+
223233 if (updateCenters) {
224234 for (int i = 1 ; i < args._T ; i++) {
225235 float * currCenter = args.newCenters + i*args._K *args._D ;
@@ -768,7 +778,9 @@ void ProcessWithoutMPI() {
768778 COMMON::KmeansArgs<T> args (options.m_clusterNum , vectors->Dimension (), vectors->Count (), options.m_threadNum , options.m_distMethod );
769779 COMMON::Dataset<LabelType> label (vectors->Count (), options.m_clusterassign , vectors->Count (), vectors->Count ());
770780 std::vector<SizeType> localindices (data.R (), 0 );
771- for (SizeType i = 0 ; i < data.R (); i++) localindices[i] = i;
781+ for (SizeType i = 0 ; i < data.R (); i++) {
782+ localindices[i] = i;
783+ }
772784 args.ClearCounts ();
773785
774786 unsigned long long totalCount;
0 commit comments