Skip to content

Commit de5b7f8

Browse files
authored
add logger for total distance (#351)
1 parent dda1180 commit de5b7f8

File tree

1 file changed

+15
-3
lines changed
  • AnnService/src/BalancedDataPartition

1 file changed

+15
-3
lines changed

AnnService/src/BalancedDataPartition/main.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// Copyright (c) Microsoft Corporation. All rights reserved.
2-
// Licensed under the MIT License.
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
33

44
#include <mpi.h>
55
#include <thread>
@@ -156,6 +156,8 @@ inline float MultipleClustersAssign(const COMMON::Dataset<T>& data,
156156
for (int k = 0; k < args._K; k++) avgCount += args.counts[k];
157157
avgCount /= args._K;
158158

159+
std::vector<float> dist_total(args._K * args._T, 0);
160+
159161
auto func = [&](int tid)
160162
{
161163
SizeType istart = first + tid * subsize;
@@ -165,6 +167,7 @@ inline float MultipleClustersAssign(const COMMON::Dataset<T>& data,
165167
float* inewCenters = args.newCenters + tid * args._K * args._D;
166168
SizeType* iclusterIdx = args.clusterIdx + tid * args._K;
167169
float* iclusterDist = args.clusterDist + tid * args._K;
170+
float* idist_total = dist_total.data() + tid * args._K;
168171
float idist = 0;
169172
std::vector<SPTAG::NodeDistPair> centerDist(args._K, SPTAG::NodeDistPair());
170173
for (SizeType i = istart; i < iend; i++) {
@@ -184,6 +187,7 @@ inline float MultipleClustersAssign(const COMMON::Dataset<T>& data,
184187
inewCounts[centerDist[k].node]++;
185188
inewWeightedCounts[centerDist[k].node] += weights[indices[i]];
186189
idist += centerDist[k].distance;
190+
idist_total[centerDist[k].node] += centerDist[k].distance;
187191

188192
if (updateCenters) {
189193
const T* v = (const T*)data[indices[i]];
@@ -217,9 +221,15 @@ inline float MultipleClustersAssign(const COMMON::Dataset<T>& data,
217221
for (int k = 0; k < args._K; k++) {
218222
args.newCounts[k] += args.newCounts[i*args._K + k];
219223
args.newWeightedCounts[k] += args.newWeightedCounts[i*args._K + k];
224+
dist_total[k] += dist_total[i * args._K + k];
220225
}
221226
}
222227

228+
LOG(Helper::LogLevel::LL_Info, "start printing dist_total\n");
229+
for (int k = 0; k < args._K; k++) {
230+
LOG(Helper::LogLevel::LL_Info, "cluster %d: dist_total:%f \n", k, dist_total[k]);
231+
}
232+
223233
if (updateCenters) {
224234
for (int i = 1; i < args._T; i++) {
225235
float* currCenter = args.newCenters + i*args._K*args._D;
@@ -768,7 +778,9 @@ void ProcessWithoutMPI() {
768778
COMMON::KmeansArgs<T> args(options.m_clusterNum, vectors->Dimension(), vectors->Count(), options.m_threadNum, options.m_distMethod);
769779
COMMON::Dataset<LabelType> label(vectors->Count(), options.m_clusterassign, vectors->Count(), vectors->Count());
770780
std::vector<SizeType> localindices(data.R(), 0);
771-
for (SizeType i = 0; i < data.R(); i++) localindices[i] = i;
781+
for (SizeType i = 0; i < data.R(); i++) {
782+
localindices[i] = i;
783+
}
772784
args.ClearCounts();
773785

774786
unsigned long long totalCount;

0 commit comments

Comments
 (0)