-
Notifications
You must be signed in to change notification settings - Fork 451
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support Agglomerative clustering. (#1384)
We use the open-source implementation from https://github.com/cdalitz/hclust-cpp
- Loading branch information
1 parent
bc08160
commit 70568c2
Showing
12 changed files
with
343 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
function(download_hclust_cpp) | ||
include(FetchContent) | ||
|
||
# The latest commit as of 2024.09.29 | ||
set(hclust_cpp_URL "https://github.com/csukuangfj/hclust-cpp/archive/refs/tags/2024-09-29.tar.gz") | ||
set(hclust_cpp_HASH "SHA256=abab51448a3cb54272aae07522970306e0b2cc6479d59d7b19e7aee4d6cedd33") | ||
|
||
# If you don't have access to the Internet, | ||
# please pre-download hclust-cpp | ||
set(possible_file_locations | ||
$ENV{HOME}/Downloads/hclust-cpp-2024-09-29.tar.gz | ||
${CMAKE_SOURCE_DIR}/hclust-cpp-2024-09-29.tar.gz | ||
${CMAKE_BINARY_DIR}/hclust-cpp-2024-09-29.tar.gz | ||
/tmp/hclust-cpp-2024-09-29.tar.gz | ||
/star-fj/fangjun/download/github/hclust-cpp-2024-09-29.tar.gz | ||
) | ||
|
||
foreach(f IN LISTS possible_file_locations) | ||
if(EXISTS ${f}) | ||
set(hclust_cpp_URL "${f}") | ||
file(TO_CMAKE_PATH "${hclust_cpp_URL}" hclust_cpp_URL) | ||
message(STATUS "Found local downloaded hclust_cpp: ${hclust_cpp_URL}") | ||
break() | ||
endif() | ||
endforeach() | ||
|
||
FetchContent_Declare(hclust_cpp | ||
URL | ||
${hclust_cpp_URL} | ||
${hclust_cpp_URL2} | ||
URL_HASH ${hclust_cpp_HASH} | ||
) | ||
|
||
FetchContent_GetProperties(hclust_cpp) | ||
if(NOT hclust_cpp_POPULATED) | ||
message(STATUS "Downloading hclust_cpp from ${hclust_cpp_URL}") | ||
FetchContent_Populate(hclust_cpp) | ||
endif() | ||
|
||
message(STATUS "hclust_cpp is downloaded to ${hclust_cpp_SOURCE_DIR}") | ||
message(STATUS "hclust_cpp's binary dir is ${hclust_cpp_BINARY_DIR}") | ||
include_directories(${hclust_cpp_SOURCE_DIR}) | ||
endfunction() | ||
|
||
download_hclust_cpp() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
// sherpa-onnx/csrc/fast-clustering-config.cc | ||
// | ||
// Copyright (c) 2024 Xiaomi Corporation | ||
|
||
#include "sherpa-onnx/csrc/fast-clustering-config.h" | ||
|
||
#include <sstream> | ||
#include <string> | ||
|
||
#include "sherpa-onnx/csrc/macros.h" | ||
|
||
namespace sherpa_onnx { | ||
std::string FastClusteringConfig::ToString() const { | ||
std::ostringstream os; | ||
|
||
os << "FastClusteringConfig("; | ||
os << "num_clusters=" << num_clusters << ", "; | ||
os << "threshold=" << threshold << ")"; | ||
|
||
return os.str(); | ||
} | ||
|
||
void FastClusteringConfig::Register(ParseOptions *po) { | ||
std::string prefix = "ctc"; | ||
ParseOptions p(prefix, po); | ||
|
||
p.Register("num-clusters", &num_clusters, | ||
"Number of cluster. If greater than 0, then --cluster-thresold is " | ||
"ignored"); | ||
|
||
p.Register("cluster-threshold", &threshold, | ||
"If --num-clusters is not specified, then it specifies the " | ||
"distance threshold for clustering."); | ||
} | ||
|
||
bool FastClusteringConfig::Validate() const { | ||
if (num_clusters < 1 && threshold < 0) { | ||
SHERPA_ONNX_LOGE("Please provide either num_clusters or threshold"); | ||
return false; | ||
} | ||
|
||
return true; | ||
} | ||
|
||
} // namespace sherpa_onnx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
// sherpa-onnx/csrc/fast-clustering-config.h | ||
// | ||
// Copyright (c) 2024 Xiaomi Corporation | ||
|
||
#ifndef SHERPA_ONNX_CSRC_FAST_CLUSTERING_CONFIG_H_ | ||
#define SHERPA_ONNX_CSRC_FAST_CLUSTERING_CONFIG_H_ | ||
|
||
#include <string> | ||
|
||
#include "sherpa-onnx/csrc/parse-options.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
struct FastClusteringConfig { | ||
// If greater than 0, then threshold is ignored | ||
int32_t num_clusters = -1; | ||
|
||
// distance threshold | ||
float threshold = 0.5; | ||
|
||
std::string ToString() const; | ||
|
||
void Register(ParseOptions *po); | ||
bool Validate() const; | ||
}; | ||
|
||
} // namespace sherpa_onnx | ||
#endif // SHERPA_ONNX_CSRC_FAST_CLUSTERING_CONFIG_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
// sherpa-onnx/csrc/fast-clustering-test.cc | ||
// | ||
// Copyright (c) 2024 Xiaomi Corporation | ||
|
||
#include "sherpa-onnx/csrc/fast-clustering.h" | ||
|
||
#include <vector> | ||
|
||
#include "gtest/gtest.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
TEST(FastClustering, TestTwoClusters) { | ||
std::vector<float> features = { | ||
// point 0 | ||
0.1, | ||
0.1, | ||
// point 2 | ||
0.4, | ||
-0.5, | ||
// point 3 | ||
0.6, | ||
-0.7, | ||
// point 1 | ||
0.2, | ||
0.3, | ||
}; | ||
|
||
FastClusteringConfig config; | ||
config.num_clusters = 2; | ||
|
||
FastClustering clustering(config); | ||
auto labels = clustering.Cluster(features.data(), 4, 2); | ||
int32_t k = 0; | ||
for (auto i : labels) { | ||
std::cout << "point " << k << ": label " << i << "\n"; | ||
++k; | ||
} | ||
} | ||
|
||
TEST(FastClustering, TestClusteringWithThreshold) { | ||
std::vector<float> features = { | ||
// point 0 | ||
0.1, | ||
0.1, | ||
// point 2 | ||
0.4, | ||
-0.5, | ||
// point 3 | ||
0.6, | ||
-0.7, | ||
// point 1 | ||
0.2, | ||
0.3, | ||
}; | ||
|
||
FastClusteringConfig config; | ||
config.threshold = 0.5; | ||
|
||
FastClustering clustering(config); | ||
auto labels = clustering.Cluster(features.data(), 4, 2); | ||
int32_t k = 0; | ||
for (auto i : labels) { | ||
std::cout << "point " << k << ": label " << i << "\n"; | ||
++k; | ||
} | ||
} | ||
|
||
} // namespace sherpa_onnx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
// sherpa-onnx/csrc/fast-clustering.cc | ||
// | ||
// Copyright (c) 2024 Xiaomi Corporation | ||
|
||
#include "sherpa-onnx/csrc/fast-clustering.h" | ||
|
||
#include <vector> | ||
|
||
#include "Eigen/Dense" | ||
#include "fastcluster-all-in-one.h" // NOLINT | ||
|
||
namespace sherpa_onnx { | ||
|
||
class FastClustering::Impl { | ||
public: | ||
explicit Impl(const FastClusteringConfig &config) : config_(config) {} | ||
|
||
std::vector<int32_t> Cluster(float *features, int32_t num_rows, | ||
int32_t num_cols) { | ||
if (num_rows <= 0) { | ||
return {}; | ||
} | ||
|
||
if (num_rows == 1) { | ||
return {0}; | ||
} | ||
|
||
Eigen::Map< | ||
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> | ||
m(features, num_rows, num_cols); | ||
m.rowwise().normalize(); | ||
|
||
std::vector<double> distance((num_rows * (num_rows - 1)) / 2); | ||
|
||
int32_t k = 0; | ||
for (int32_t i = 0; i != num_rows; ++i) { | ||
auto v = m.row(i); | ||
for (int32_t j = i + 1; j != num_rows; ++j) { | ||
double cosine_similarity = v.dot(m.row(j)); | ||
double consine_dissimilarity = 1 - cosine_similarity; | ||
|
||
if (consine_dissimilarity < 0) { | ||
consine_dissimilarity = 0; | ||
} | ||
|
||
distance[k] = consine_dissimilarity; | ||
++k; | ||
} | ||
} | ||
|
||
std::vector<int32_t> merge(2 * (num_rows - 1)); | ||
std::vector<double> height(num_rows - 1); | ||
|
||
fastclustercpp::hclust_fast(num_rows, distance.data(), | ||
fastclustercpp::HCLUST_METHOD_SINGLE, | ||
merge.data(), height.data()); | ||
|
||
std::vector<int32_t> labels(num_rows); | ||
if (config_.num_clusters > 0) { | ||
fastclustercpp::cutree_k(num_rows, merge.data(), config_.num_clusters, | ||
labels.data()); | ||
} else { | ||
fastclustercpp::cutree_cdist(num_rows, merge.data(), height.data(), | ||
config_.threshold, labels.data()); | ||
} | ||
|
||
return labels; | ||
} | ||
|
||
private: | ||
FastClusteringConfig config_; | ||
}; | ||
|
||
FastClustering::FastClustering(const FastClusteringConfig &config) | ||
: impl_(std::make_unique<Impl>(config)) {} | ||
|
||
FastClustering::~FastClustering() = default; | ||
|
||
std::vector<int32_t> FastClustering::Cluster(float *features, int32_t num_rows, | ||
int32_t num_cols) { | ||
return impl_->Cluster(features, num_rows, num_cols); | ||
} | ||
} // namespace sherpa_onnx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
// sherpa-onnx/csrc/fast-clustering.h | ||
// | ||
// Copyright (c) 2024 Xiaomi Corporation | ||
|
||
#ifndef SHERPA_ONNX_CSRC_FAST_CLUSTERING_H_ | ||
#define SHERPA_ONNX_CSRC_FAST_CLUSTERING_H_ | ||
|
||
#include <memory> | ||
#include <vector> | ||
|
||
#include "sherpa-onnx/csrc/fast-clustering-config.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
class FastClustering { | ||
public: | ||
explicit FastClustering(const FastClusteringConfig &config); | ||
~FastClustering(); | ||
|
||
/** | ||
* @param features Pointer to a 2-D feature matrix in row major. Each row | ||
* is a feature frame. It is changed in-place. We will | ||
* convert each feature frame to a normalized vector. | ||
* That is, the L2-norm of each vector will be equal to 1. | ||
* It uses cosine dissimilarity, | ||
* which is 1 - (cosine similarity) | ||
* @param num_rows Number of feature frames | ||
* @param num-cols The feature dimension. | ||
* | ||
* @return Return a vector of size num_rows. ans[i] contains the label | ||
* for the i-th feature frame, i.e., the i-th row of the feature | ||
* matrix. | ||
*/ | ||
std::vector<int32_t> Cluster(float *features, int32_t num_rows, | ||
int32_t num_cols); | ||
|
||
private: | ||
class Impl; | ||
std::unique_ptr<Impl> impl_; | ||
}; | ||
|
||
} // namespace sherpa_onnx | ||
#endif // SHERPA_ONNX_CSRC_FAST_CLUSTERING_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.