Skip to content

Commit

Permalink
Use parallel sort for quantile calculation when appropriate. (#11275)
Browse files Browse the repository at this point in the history
- Drop the duplicated `omp_in_parallel` check.
- Use parallel sort instead of parallel leaf values based on a heuristic.
  • Loading branch information
trivialfis authored Feb 24, 2025
1 parent 5006fe7 commit 688c2f5
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 31 deletions.
7 changes: 5 additions & 2 deletions src/common/algorithm.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
/**
* Copyright 2022-2023 by XGBoost Contributors
* Copyright 2022-2025, XGBoost Contributors
*/
#ifndef XGBOOST_COMMON_ALGORITHM_H_
#define XGBOOST_COMMON_ALGORITHM_H_
#include <algorithm> // upper_bound, stable_sort, sort, max
#include <cinttypes> // size_t
#include <cstddef> // size_t
#include <functional> // less
#include <iterator> // iterator_traits, distance
#include <vector> // vector
Expand All @@ -16,6 +16,9 @@
#if defined(__GNUC__) && (__GNUC__ >= 4) && !defined(__sun) && !defined(sun) && \
!defined(__APPLE__) && __has_include(<omp.h>) && __has_include(<parallel/algorithm>)
#define GCC_HAS_PARALLEL 1
constexpr bool kHasParallelStableSort = true;
#else
constexpr bool kHasParallelStableSort = false;
#endif // GLIC_VERSION

#if defined(_MSC_VER) && !defined(__INTEL_COMPILER)
Expand Down
39 changes: 18 additions & 21 deletions src/common/stats.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2022-2024, XGBoost Contributors
* Copyright 2022-2025, XGBoost Contributors
*/
#ifndef XGBOOST_COMMON_STATS_H_
#define XGBOOST_COMMON_STATS_H_
Expand Down Expand Up @@ -32,8 +32,9 @@ namespace common {
*
* \return The result of interpolation.
*/
template <typename Iter>
float Quantile(Context const* ctx, double alpha, Iter const& begin, Iter const& end) {
template <typename Iter,
typename R = std::remove_reference_t<typename std::iterator_traits<Iter>::value_type>>
[[nodiscard]] R Quantile(Context const* ctx, double alpha, Iter const& begin, Iter const& end) {
CHECK(alpha >= 0 && alpha <= 1);
auto n = static_cast<double>(std::distance(begin, end));
if (n == 0) {
Expand All @@ -42,15 +43,12 @@ float Quantile(Context const* ctx, double alpha, Iter const& begin, Iter const&

std::vector<std::size_t> sorted_idx(n);
std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
if (omp_in_parallel()) {
std::stable_sort(sorted_idx.begin(), sorted_idx.end(),
[&](std::size_t l, std::size_t r) { return *(begin + l) < *(begin + r); });
} else {
StableSort(ctx, sorted_idx.begin(), sorted_idx.end(),
[&](std::size_t l, std::size_t r) { return *(begin + l) < *(begin + r); });
}
StableSort(ctx, sorted_idx.begin(), sorted_idx.end(),
[&](std::size_t l, std::size_t r) { return *(begin + l) < *(begin + r); });

auto val = [&](size_t i) { return *(begin + sorted_idx[i]); };
auto val = [&](size_t i) {
return *(begin + sorted_idx[i]);
};
static_assert(std::is_same_v<decltype(val(0)), float>);

if (alpha <= (1 / (n + 1))) {
Expand All @@ -77,23 +75,22 @@ float Quantile(Context const* ctx, double alpha, Iter const& begin, Iter const&
* See https://aakinshin.net/posts/weighted-quantiles/ for some discussions on computing
* weighted quantile with interpolation.
*/
template <typename Iter, typename WeightIter>
float WeightedQuantile(Context const* ctx, double alpha, Iter begin, Iter end, WeightIter w_begin) {
template <typename Iter, typename WeightIter,
typename R = std::remove_reference_t<typename std::iterator_traits<Iter>::value_type>>
[[nodiscard]] R WeightedQuantile(Context const* ctx, double alpha, Iter begin, Iter end,
WeightIter w_begin) {
auto n = static_cast<double>(std::distance(begin, end));
if (n == 0) {
return std::numeric_limits<float>::quiet_NaN();
}
std::vector<size_t> sorted_idx(n);
std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
if (omp_in_parallel()) {
std::stable_sort(sorted_idx.begin(), sorted_idx.end(),
[&](std::size_t l, std::size_t r) { return *(begin + l) < *(begin + r); });
} else {
StableSort(ctx, sorted_idx.begin(), sorted_idx.end(),
[&](std::size_t l, std::size_t r) { return *(begin + l) < *(begin + r); });
}
StableSort(ctx, sorted_idx.begin(), sorted_idx.end(),
[&](std::size_t l, std::size_t r) { return *(begin + l) < *(begin + r); });

auto val = [&](size_t i) { return *(begin + sorted_idx[i]); };
auto val = [&](size_t i) {
return *(begin + sorted_idx[i]);
};

std::vector<float> weight_cdf(n); // S_n
// weighted cdf is sorted during construction
Expand Down
17 changes: 13 additions & 4 deletions src/common/threading_utils.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2019-2024, XGBoost Contributors
* Copyright 2019-2025, XGBoost Contributors
*/
#ifndef XGBOOST_COMMON_THREADING_UTILS_H_
#define XGBOOST_COMMON_THREADING_UTILS_H_
Expand All @@ -14,6 +14,7 @@
#include <new> // for bad_alloc
#include <thread> // for thread
#include <type_traits> // for is_signed, conditional_t, is_integral_v, invoke_result_t
#include <utility> // for forward
#include <vector> // for vector

#include "xgboost/logging.h"
Expand Down Expand Up @@ -181,7 +182,15 @@ struct Sched {
};

template <typename Index, typename Func>
void ParallelFor(Index size, int32_t n_threads, Sched sched, Func fn) {
void ParallelFor(Index size, std::int32_t n_threads, Sched sched, Func&& fn) {
if (n_threads == 1) {
// early exit
for (Index i = 0; i < size; ++i) {
fn(i);
}
return;
}

#if defined(_MSC_VER)
// msvc doesn't support unsigned integer as openmp index.
using OmpInd = std::conditional_t<std::is_signed<Index>::value, Index, omp_ulong>;
Expand Down Expand Up @@ -240,8 +249,8 @@ void ParallelFor(Index size, int32_t n_threads, Sched sched, Func fn) {
}

template <typename Index, typename Func>
void ParallelFor(Index size, int32_t n_threads, Func fn) {
ParallelFor(size, n_threads, Sched::Static(), fn);
void ParallelFor(Index size, std::int32_t n_threads, Func&& fn) {
ParallelFor(size, n_threads, Sched::Static(), std::forward<Func>(fn));
}

inline std::int32_t OmpGetThreadLimit() {
Expand Down
23 changes: 21 additions & 2 deletions src/objective/adaptive.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2022-2024, XGBoost Contributors
* Copyright 2022-2025, XGBoost Contributors
*/
#include "adaptive.h"

Expand Down Expand Up @@ -104,10 +104,29 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
auto h_predt = linalg::MakeTensorView(ctx, predt.ConstHostSpan(), info.num_row_,
predt.Size() / info.num_row_);

// A heuristic to use parallel sort. If we use multiple threads here, the sorting is
// performed using a single thread as openmp cannot allocate new threads inside a
// parallel region.
std::int32_t n_threads;
if constexpr (kHasParallelStableSort) {
CHECK_GE(h_node_ptr.size(), 1);
auto it = common::MakeIndexTransformIter(
[&](std::size_t i) { return h_node_ptr[i + 1] - h_node_ptr[i]; });
n_threads = std::any_of(it, it + h_node_ptr.size() - 1,
[](auto n) {
constexpr std::size_t kNeedParallelSort = 1ul << 19;
return n > kNeedParallelSort;
})
? 1
: ctx->Threads();
} else {
n_threads = ctx->Threads();
}

collective::ApplyWithLabels(
ctx, info, static_cast<void*>(quantiles.data()), quantiles.size() * sizeof(float), [&] {
// loop over each leaf
common::ParallelFor(quantiles.size(), ctx->Threads(), [&](size_t k) {
common::ParallelFor(quantiles.size(), n_threads, [&](size_t k) {
auto nidx = h_node_idx[k];
CHECK(tree[nidx].IsLeaf());
CHECK_LT(k + 1, h_node_ptr.size());
Expand Down
5 changes: 3 additions & 2 deletions tests/cpp/common/test_threading_utils.cc
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
/**
* Copyright 2019-2024, XGBoost Contributors
*/
#include <dmlc/omp.h> // for omp_in_parallel
#include <gtest/gtest.h>

#include <cstddef> // std::size_t
#include <cstddef> // for std::size_t

#include "../../../src/common/threading_utils.h" // BlockedSpace2d,ParallelFor2d,ParallelFor
#include "dmlc/omp.h" // omp_in_parallel
#include "xgboost/context.h" // Context

namespace xgboost::common {
Expand Down Expand Up @@ -99,6 +99,7 @@ TEST(ParallelFor, Basic) {
ASSERT_LT(i, n);
});
ASSERT_FALSE(omp_in_parallel());
ParallelFor(n, 1, [&](auto) { ASSERT_FALSE(omp_in_parallel()); });
}

TEST(OmpGetNumThreads, Max) {
Expand Down

0 comments on commit 688c2f5

Please sign in to comment.