Skip to content

Commit

Permalink
more work
Browse files Browse the repository at this point in the history
  • Loading branch information
hosseinmoein committed Nov 4, 2024
1 parent 89c17cd commit 101e81c
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 139 deletions.
4 changes: 3 additions & 1 deletion include/DataFrame/DataFrameMLVisitors.h
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,9 @@ struct FastHierVisitor {

generic_linkage_vector<method_codes_vector::METHOD_VECTOR_WARD>
(col_s, diss, clus_res);
// diss.postprocess(clus_res);
// generic_linkage<method_codes::METHOD_METR_WARD>
// (col_s, diss_vec.data(), members.data(), clus_res);
diss.postprocess(clus_res);

vec_t<double> Z((col_s - 1) * 4, 0);

Expand Down
142 changes: 58 additions & 84 deletions include/DataFrame/Internals/fastcluster.h
Original file line number Diff line number Diff line change
Expand Up @@ -976,10 +976,10 @@ f_median(double *const b, double a, double c_4) {

// ----------------------------------------------------------------------------

template<method_codes method, typename t_members>
template<method_codes method>
static void NN_chain_core(std::size_t N,
double *const D,
t_members *const members,
std::size_t *const members,
ClusterResult &Z2) {

/*
Expand Down Expand Up @@ -1064,9 +1064,8 @@ static void NN_chain_core(std::size_t N,
// Remove the smaller index from the valid indices (active_nodes).
active_nodes.remove(idx1);

switch (method) {
case method_codes::METHOD_METR_SINGLE:
/*
if constexpr (method == method_codes::METHOD_METR_SINGLE) {
/*
Single linkage.
Characteristic: new distances are never longer than the
old distances.
Expand All @@ -1080,9 +1079,8 @@ static void NN_chain_core(std::size_t N,
// Update the distance matrix in the range (idx2, N).
for (i = active_nodes.succ[idx2]; i < N; i = active_nodes.succ[i])
f_single(&D_(idx2, i), D_(idx1, i));
break;

case method_codes::METHOD_METR_COMPLETE:
}
else if constexpr (method == method_codes::METHOD_METR_COMPLETE) {
/*
Complete linkage.
Characteristic: new distances are never shorter than the
Expand All @@ -1097,9 +1095,8 @@ static void NN_chain_core(std::size_t N,
// Update the distance matrix in the range (idx2, N).
for (i = active_nodes.succ[idx2]; i < N; i = active_nodes.succ[i])
f_complete(&D_(idx2, i), D_(idx1, i));
break;

case method_codes::METHOD_METR_AVERAGE: {
}
else if constexpr (method == method_codes::METHOD_METR_AVERAGE) {
/*
Average linkage.
Shorter and longer distances can occur.
Expand All @@ -1116,10 +1113,8 @@ static void NN_chain_core(std::size_t N,
// Update the distance matrix in the range (idx2, N).
for (i = active_nodes.succ[idx2]; i < N; i = active_nodes.succ[i])
f_average(&D_(idx2, i), D_(idx1, i), s, t);
break;
}

case method_codes::METHOD_METR_WEIGHTED:
else if constexpr (method == method_codes::METHOD_METR_WEIGHTED) {
/*
Weighted linkage.
Shorter and longer distances can occur.
Expand All @@ -1133,9 +1128,8 @@ static void NN_chain_core(std::size_t N,
// Update the distance matrix in the range (idx2, N).
for (i = active_nodes.succ[idx2]; i < N; i = active_nodes.succ[i])
f_weighted(&D_(idx2, i), D_(idx1, i));
break;

case method_codes::METHOD_METR_WARD:
}
else if constexpr (method == method_codes::METHOD_METR_WARD) {
/*
Ward linkage.
Shorter and longer distances can occur, not smaller than min(d1, d2)
Expand All @@ -1154,9 +1148,8 @@ static void NN_chain_core(std::size_t N,
for (i = active_nodes.succ[idx2]; i < N; i = active_nodes.succ[i])
f_ward(&D_(idx2, i), D_(idx1, i), min,
size1, size2, static_cast<double>(members[i]));
break;

default:
}
else {
throw std::runtime_error("NN_chain_core(): Invalid method.");
}
}
Expand Down Expand Up @@ -1342,10 +1335,10 @@ class BinaryMinHeap {

// ----------------------------------------------------------------------------

template <method_codes method, typename t_members>
template <method_codes method>
static void generic_linkage(std::size_t N,
double *const D,
t_members *const members,
std::size_t *const members,
ClusterResult &Z2) {

/*
Expand Down Expand Up @@ -1475,8 +1468,7 @@ static void generic_linkage(std::size_t N,
row_repr[idx2] = N + i;

// Update the distance matrix
switch (method) {
case method_codes::METHOD_METR_SINGLE:
if constexpr (method == method_codes::METHOD_METR_SINGLE) {
/*
Single linkage.
Characteristic: new distances are never longer than the
Expand Down Expand Up @@ -1512,9 +1504,8 @@ static void generic_linkage(std::size_t N,
}
nn_distances.update_leq(idx2, min);
}
break;

case method_codes::METHOD_METR_COMPLETE:
}
else if constexpr (method == method_codes::METHOD_METR_COMPLETE) {
/*
Complete linkage.
Characteristic: new distances are never shorter than
Expand All @@ -1532,9 +1523,8 @@ static void generic_linkage(std::size_t N,
// Update the distance matrix in the range (idx2, N).
for (j = active_nodes.succ[idx2]; j < N; j = active_nodes.succ[j])
f_complete(&D_(idx2, j), D_(idx1, j));
break;

case method_codes::METHOD_METR_AVERAGE: {
}
else if constexpr (method == method_codes::METHOD_METR_AVERAGE) {
/*
Average linkage.
Shorter and longer distances can occur.
Expand Down Expand Up @@ -1572,10 +1562,8 @@ static void generic_linkage(std::size_t N,
}
nn_distances.update(idx2, min);
}
break;
}

case method_codes::METHOD_METR_WEIGHTED:
else if constexpr (method == method_codes::METHOD_METR_WEIGHTED) {
/*
Weighted linkage.
Shorter and longer distances can occur.
Expand Down Expand Up @@ -1610,9 +1598,8 @@ static void generic_linkage(std::size_t N,
}
nn_distances.update(idx2, min);
}
break;

case method_codes::METHOD_METR_WARD:
}
else if constexpr (method == method_codes::METHOD_METR_WARD) {
/*
Ward linkage.
Shorter and longer distances can occur, not smaller
Expand Down Expand Up @@ -1652,9 +1639,8 @@ static void generic_linkage(std::size_t N,
}
nn_distances.update(idx2, min);
}
break;

case method_codes::METHOD_METR_CENTROID: {
}
else if constexpr (method == method_codes::METHOD_METR_CENTROID) {
/*
Centroid linkage.
Shorter and longer distances can occur, not bigger than max(d1, d2)
Expand Down Expand Up @@ -1698,10 +1684,8 @@ static void generic_linkage(std::size_t N,
}
nn_distances.update(idx2, min);
}
break;
}

case method_codes::METHOD_METR_MEDIAN: {
else if constexpr (method == method_codes::METHOD_METR_MEDIAN) {
/*
Median linkage.
Shorter and longer distances can occur, not bigger than max(d1, d2)
Expand Down Expand Up @@ -1743,10 +1727,8 @@ static void generic_linkage(std::size_t N,
}
nn_distances.update(idx2, min);
}
break;
}

default:
else {
throw std::runtime_error("generic_linkage(): Invalid method.");
}
}
Expand Down Expand Up @@ -1854,30 +1836,27 @@ generic_linkage_vector(std::size_t N, Dissimilarity &dist, ClusterResult &Z2) {
// Find the nearest neighbor of each point.
// n_nghbr[i] = argmin_{j > i} D(i, j) for i in range(N-1)
for (i = 0; i < N_1; ++i) {
min = std::numeric_limits<double>::infinity();

std::size_t idx;

min = std::numeric_limits<double>::infinity();
for (idx = j = i + 1; j < N; ++j) {
double tmp;

switch (method) {
case method_codes_vector::METHOD_VECTOR_WARD:
if constexpr (method == method_codes_vector::METHOD_VECTOR_WARD) {
tmp = dist.ward_initial(i, j);
break;
default:
}
else {
tmp = dist.sqeuclidean(i, j);
}
if (tmp < min) {
min = tmp;
idx = j;
}
}
switch (method) {
case method_codes_vector::METHOD_VECTOR_WARD:
if constexpr (method == method_codes_vector::METHOD_VECTOR_WARD) {
mindist[i] = Dissimilarity::ward_initial_conversion(min);
break;
default:
}
else {
mindist[i] = min;
}
n_nghbr[i] = idx;
Expand All @@ -1895,8 +1874,7 @@ generic_linkage_vector(std::size_t N, Dissimilarity &dist, ClusterResult &Z2) {
// Recompute the minimum mindist[idx1] and n_nghbr[idx1].
// exists, maximally N-1
n_nghbr[idx1] = j = active_nodes.succ[idx1];
switch (method) {
case method_codes_vector::METHOD_VECTOR_WARD:
if constexpr (method == method_codes_vector::METHOD_VECTOR_WARD) {
min = dist.ward(idx1, j);
for (j = active_nodes.succ[j]; j < N;
j = active_nodes.succ[j]) {
Expand All @@ -1907,8 +1885,8 @@ generic_linkage_vector(std::size_t N, Dissimilarity &dist, ClusterResult &Z2) {
n_nghbr[idx1] = j;
}
}
break;
default:
}
else {
min = dist.sqeuclidean(idx1, j);
for (j = active_nodes.succ[j]; j < N;
j = active_nodes.succ[j]) {
Expand All @@ -1935,15 +1913,15 @@ generic_linkage_vector(std::size_t N, Dissimilarity &dist, ClusterResult &Z2) {

Z2.append(node1, node2, mindist[idx1]);

switch (method) {
case method_codes_vector::METHOD_VECTOR_WARD:
case method_codes_vector::METHOD_VECTOR_CENTROID:
if constexpr (method == method_codes_vector::METHOD_VECTOR_WARD ||
method == method_codes_vector::METHOD_VECTOR_CENTROID) {
dist.merge_inplace(idx1, idx2);
break;
case method_codes_vector::METHOD_VECTOR_MEDIAN:
}
else if constexpr (method ==
method_codes_vector::METHOD_VECTOR_MEDIAN) {
dist.merge_inplace_weighted(idx1, idx2);
break;
default:
}
else {
throw std::runtime_error("generic_linkage_vector(): "
"Invalid method.");
}
Expand All @@ -1954,8 +1932,7 @@ generic_linkage_vector(std::size_t N, Dissimilarity &dist, ClusterResult &Z2) {
active_nodes.remove(idx1); // TBD later!!!

// Update the distance matrix
switch (method) {
case method_codes_vector::METHOD_VECTOR_WARD:
if constexpr (method == method_codes_vector::METHOD_VECTOR_WARD) {
/*
Ward linkage.
Shorter and longer distances can occur, not smaller
Expand Down Expand Up @@ -1995,9 +1972,8 @@ generic_linkage_vector(std::size_t N, Dissimilarity &dist, ClusterResult &Z2) {
}
nn_distances.update(idx2, min);
}
break;

default:
}
else {
/*
Centroid and median linkage.
Shorter and longer distances can occur, not bigger
Expand Down Expand Up @@ -2114,8 +2090,7 @@ static void generic_linkage_vector_alternative(std::size_t N,
// Recompute the minimum mindist[idx1] and n_nghbr[idx1].
n_nghbr[idx1] = j = active_nodes.start;

switch (method) {
case method_codes_vector::METHOD_VECTOR_WARD:
if constexpr (method == method_codes_vector::METHOD_VECTOR_WARD) {
min = dist.ward_extended(idx1, j);
for (j = active_nodes.succ[j]; j < idx1;
j = active_nodes.succ[j]) {
Expand All @@ -2126,8 +2101,8 @@ static void generic_linkage_vector_alternative(std::size_t N,
n_nghbr[idx1] = j;
}
}
break;
default:
}
else {
min = dist.sqeuclidean_extended(idx1, j);
for (j = active_nodes.succ[j]; j < idx1;
j = active_nodes.succ[j]) {
Expand All @@ -2152,17 +2127,16 @@ static void generic_linkage_vector_alternative(std::size_t N,
Z2.append(idx1, idx2, mindist[idx1]);

if (i < (2 * N_1)) {
switch (method) {
case method_codes_vector::METHOD_VECTOR_WARD:
case method_codes_vector::METHOD_VECTOR_CENTROID:
if constexpr (method == method_codes_vector::METHOD_VECTOR_WARD ||
method ==
method_codes_vector::METHOD_VECTOR_CENTROID) {
dist.merge(idx1, idx2, i);
break;

case method_codes_vector::METHOD_VECTOR_MEDIAN:
}
else if constexpr (method ==
method_codes_vector::METHOD_VECTOR_MEDIAN) {
dist.merge_weighted(idx1, idx2, i);
break;

default:
}
else {
throw std::runtime_error(
"generic_linkage_vector_alternative(): Invalid method.");
}
Expand Down
Loading

0 comments on commit 101e81c

Please sign in to comment.