Skip to content

Commit

Permalink
Made Matrix::covariance() multithreaded
Browse files Browse the repository at this point in the history
  • Loading branch information
hosseinmoein committed Dec 12, 2024
1 parent 18fc7df commit d6fa594
Showing 1 changed file with 26 additions and 12 deletions.
38 changes: 26 additions & 12 deletions include/DataFrame/Utils/Matrix.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -1226,24 +1226,38 @@ covariance(bool is_unbiased) const {
if (denom <= value_type(0))
throw DataFrameError("Matrix::covariance(): Not solvable");

Matrix result (cols(), cols(), T(0));
Matrix result (cols(), cols(), T(0));
auto lbd =
[&result, this, denom](auto begin, auto end) -> void {
for (size_type cr = begin; cr < end; ++cr) {
value_type mean { 0 };

for (size_type cr = 0; cr < cols(); ++cr) {
value_type col_mean { 0 };
for (size_type r = 0; r < rows(); ++r)
mean += at(r, cr);
mean /= value_type(rows());

for (size_type r = 0; r < rows(); ++r)
col_mean += at(r, cr);
col_mean /= value_type(rows());
for (size_type c = cr; c < cols(); ++c) {
value_type var_covar { 0 };

for (size_type c = 0; c < cols(); ++c) {
value_type var_covar { 0 };
for (size_type r = 0; r < rows(); ++r)
var_covar += (at(r, cr) - mean) * (at(r, c) - mean);

for (size_type r = 0; r < rows(); ++r)
var_covar += (at(r, cr) - col_mean) * (at(r, c) - col_mean);
result(cr, c) = result(c, cr) = var_covar / denom;
}
}
};
const long thread_level =
(cols() >= 50L || rows() >= 100'000L)
? ThreadGranularity::get_thread_level() : 0;

result(cr, c) = var_covar / denom;
}
if (thread_level > 2) {
auto futures =
ThreadGranularity::thr_pool_.parallel_loop(0L, cols(),
std::move(lbd));

for (auto &fut : futures) fut.get();
}
else lbd(0L, cols());

return (result);
}
Expand Down

0 comments on commit d6fa594

Please sign in to comment.