diff --git a/cpp/include/kvikio/defaults.hpp b/cpp/include/kvikio/defaults.hpp index 722986c362..aff8aa5ea8 100644 --- a/cpp/include/kvikio/defaults.hpp +++ b/cpp/include/kvikio/defaults.hpp @@ -111,7 +111,7 @@ std::tuple getenv_or( */ class defaults { private: - BS_thread_pool _thread_pool{get_num_threads_from_env()}; + ThreadPool _thread_pool{get_num_threads_from_env()}; CompatMode _compat_mode; std::size_t _task_size; std::size_t _gds_threshold; @@ -212,7 +212,7 @@ class defaults { * * @return The default thread pool instance. */ - [[nodiscard]] static BS_thread_pool& thread_pool(); + [[nodiscard]] static ThreadPool& thread_pool(); /** * @brief Get the number of threads in the default thread pool. diff --git a/cpp/include/kvikio/detail/function_wrapper.hpp b/cpp/include/kvikio/detail/function_wrapper.hpp new file mode 100644 index 0000000000..4f9d95cd94 --- /dev/null +++ b/cpp/include/kvikio/detail/function_wrapper.hpp @@ -0,0 +1,92 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include + +namespace kvikio::detail { +/** + * @brief Type-erased function wrapper that can hold a copyable or move-only callable with signature + * void(). Unlike std::function, this wrapper is move-only and cannot be copied. + * + * @todo Use small buffer optimization to avoid heap allocation for small callables. + * @note This class will be superseded by C++23's std::move_only_function. 
+ */ +class FunctionWrapper { + private: + /** + * @brief Abstract call interface through which the wrapper invokes the stored callable. + */ + struct InnerBase { + virtual void operator()() = 0; + + virtual ~InnerBase() = default; + }; + + /** + * @brief Concrete holder that stores the callable in `_f` and forwards operator() to it via + * std::invoke. + */ + template + struct Inner : InnerBase { + using F_decay = std::decay_t; + static_assert(std::is_invocable_r_v); + + explicit Inner(F&& f) : _f(std::forward(f)) {} + + void operator()() override { std::invoke(_f); } + + ~Inner() override = default; + + F_decay _f; + }; + + /** + * @brief Owning pointer to the stored callable; null when the wrapper is empty. + */ + std::unique_ptr _callable; + + public: + /** + * @brief Construct a function wrapper from a callable object. The callable must be invocable with + * no arguments and return void. It can be either copyable or move-only (e.g., a lambda capturing + * std::unique_ptr). + * + * This constructor is not `explicit`, so a callable converts implicitly to a FunctionWrapper. + * + * @note NOTE(review): the template parameter list is not visible in this view. If `F` is an + * unconstrained forwarding reference, constructing from a non-const lvalue FunctionWrapper would + * select this template instead of the deleted copy constructor; consider constraining `F` to + * exclude FunctionWrapper itself. Confirm. + * + * @tparam F Callable type. + * @param f Callable object to wrap. Will be moved or copied into the wrapper. + */ + template + FunctionWrapper(F&& f) : _callable(std::make_unique>(std::forward(f))) + { + } + + /** + * @brief Default constructor. Creates an empty wrapper with no callable target. + */ + FunctionWrapper() = default; + + /** + * @brief Move operations transfer ownership of the stored callable; copy operations are deleted, + * which makes the wrapper move-only. + */ + FunctionWrapper(FunctionWrapper&&) noexcept = default; + FunctionWrapper& operator=(FunctionWrapper&&) noexcept = default; + + FunctionWrapper(const FunctionWrapper&) = delete; + FunctionWrapper& operator=(const FunctionWrapper&) = delete; + + /** + * @brief Invoke the wrapped callable. + * + * @exception std::bad_function_call if the wrapper is empty (default-constructed or moved-from). + */ + void operator()() + { + if (!_callable) { throw std::bad_function_call(); } + _callable->operator()(); + } + + /** + * @brief Check whether the wrapper contains a callable target. + * + * @return true if the wrapper contains a callable, false if it is empty. + */ + explicit operator bool() const noexcept { return _callable != nullptr; } + + /** + * @brief Reset the wrapper to an empty state, destroying the contained callable if any. 
+ */ + void reset() noexcept { _callable.reset(); } +}; + +} // namespace kvikio::detail diff --git a/cpp/include/kvikio/detail/parallel_operation.hpp b/cpp/include/kvikio/detail/parallel_operation.hpp index a4489da8e5..1d3c43d287 100644 --- a/cpp/include/kvikio/detail/parallel_operation.hpp +++ b/cpp/include/kvikio/detail/parallel_operation.hpp @@ -17,6 +17,7 @@ #include #include #include +#include #include namespace kvikio { @@ -75,6 +76,7 @@ std::future submit_task(F op, std::size_t size, std::size_t file_offset, std::size_t devPtr_offset, + ThreadPool* thread_pool = &defaults::thread_pool(), std::uint64_t nvtx_payload = 0ull, nvtx_color_type nvtx_color = NvtxManager::default_color()) { @@ -85,7 +87,7 @@ std::future submit_task(F op, decltype(file_offset), decltype(devPtr_offset)>); - return defaults::thread_pool().submit_task([=] { + return thread_pool->submit_task([=] { KVIKIO_NVTX_SCOPED_RANGE("task", nvtx_payload, nvtx_color); return op(buf, size, file_offset, devPtr_offset); }); @@ -101,12 +103,13 @@ std::future submit_task(F op, template std::future submit_move_only_task( F op_move_only, + ThreadPool* thread_pool = &defaults::thread_pool(), std::uint64_t nvtx_payload = 0ull, nvtx_color_type nvtx_color = NvtxManager::default_color()) { static_assert(std::is_invocable_r_v); auto op_copyable = make_copyable_lambda(std::move(op_move_only)); - return defaults::thread_pool().submit_task([=] { + return thread_pool->submit_task([=] { KVIKIO_NVTX_SCOPED_RANGE("task", nvtx_payload, nvtx_color); return op_copyable(); }); @@ -124,6 +127,10 @@ std::future submit_move_only_task( * @param size Number of bytes to read or write. * @param file_offset Byte offset to the start of the file. * @param task_size Size of each task in bytes. + * @param devPtr_offset Offset relative to the `devPtr_base` pointer. This parameter should be used + * only with registered buffers. + * @param thread_pool Thread pool to use for parallel execution. 
Defaults to the global default + * thread pool. * @return A future to be used later to check if the operation has finished its execution. */ template @@ -133,10 +140,12 @@ std::future parallel_io(F op, std::size_t file_offset, std::size_t task_size, std::size_t devPtr_offset, + ThreadPool* thread_pool = &defaults::thread_pool(), std::uint64_t call_idx = 0, nvtx_color_type nvtx_color = NvtxManager::default_color()) { KVIKIO_EXPECT(task_size > 0, "`task_size` must be positive", std::invalid_argument); + KVIKIO_EXPECT(thread_pool != nullptr, "The thread pool must not be nullptr"); static_assert(std::is_invocable_r_v parallel_io(F op, // Single-task guard if (task_size >= size || get_page_size() >= size) { - return detail::submit_task(op, buf, size, file_offset, devPtr_offset, call_idx, nvtx_color); + return detail::submit_task( + op, buf, size, file_offset, devPtr_offset, thread_pool, call_idx, nvtx_color); } std::vector> tasks; @@ -154,8 +164,8 @@ std::future parallel_io(F op, // 1) Submit all tasks but the last one. These are all `task_size` sized tasks. 
while (size > task_size) { - tasks.push_back( - detail::submit_task(op, buf, task_size, file_offset, devPtr_offset, call_idx, nvtx_color)); + tasks.push_back(detail::submit_task( + op, buf, task_size, file_offset, devPtr_offset, thread_pool, call_idx, nvtx_color)); file_offset += task_size; devPtr_offset += task_size; size -= task_size; @@ -170,7 +180,7 @@ std::future parallel_io(F op, } return ret; }; - return detail::submit_move_only_task(std::move(last_task), call_idx, nvtx_color); + return detail::submit_move_only_task(std::move(last_task), thread_pool, call_idx, nvtx_color); } } // namespace kvikio diff --git a/cpp/include/kvikio/file_handle.hpp b/cpp/include/kvikio/file_handle.hpp index e74b8e3e20..b2b61a8485 100644 --- a/cpp/include/kvikio/file_handle.hpp +++ b/cpp/include/kvikio/file_handle.hpp @@ -20,6 +20,7 @@ #include #include #include +#include #include namespace kvikio { @@ -228,6 +229,8 @@ class FileHandle { * in the null stream. When in KvikIO's compatibility mode or when accessing host memory, the * operation is always default stream ordered like the rest of the non-async CUDA API. In this * case, the value of `sync_default_stream` is ignored. + * @param thread_pool Thread pool to use for parallel execution. Defaults to the global default + * thread pool. * @return Future that on completion returns the size of bytes that were successfully read. * * @note The `std::future` object's `wait()` or `get()` should not be called after the lifetime of @@ -238,7 +241,8 @@ class FileHandle { std::size_t file_offset = 0, std::size_t task_size = defaults::task_size(), std::size_t gds_threshold = defaults::gds_threshold(), - bool sync_default_stream = true); + bool sync_default_stream = true, + ThreadPool* thread_pool = &defaults::thread_pool()); /** * @brief Writes specified bytes from device or host memory into the file in parallel. @@ -265,6 +269,8 @@ class FileHandle { * in the null stream. 
When in KvikIO's compatibility mode or when accessing host memory, the * operation is always default stream ordered like the rest of the non-async CUDA API. In this * case, the value of `sync_default_stream` is ignored. + * @param thread_pool Thread pool to use for parallel execution. Defaults to the global default + * thread pool. * @return Future that on completion returns the size of bytes that were successfully written. * * @note The `std::future` object's `wait()` or `get()` should not be called after the lifetime of @@ -275,7 +281,8 @@ class FileHandle { std::size_t file_offset = 0, std::size_t task_size = defaults::task_size(), std::size_t gds_threshold = defaults::gds_threshold(), - bool sync_default_stream = true); + bool sync_default_stream = true, + ThreadPool* thread_pool = &defaults::thread_pool()); /** * @brief Reads specified bytes from the file into the device memory asynchronously. diff --git a/cpp/include/kvikio/mmap.hpp b/cpp/include/kvikio/mmap.hpp index fe8b71cbf4..4f56fa880c 100644 --- a/cpp/include/kvikio/mmap.hpp +++ b/cpp/include/kvikio/mmap.hpp @@ -9,6 +9,7 @@ #include #include +#include #include namespace kvikio { @@ -162,6 +163,8 @@ class MmapHandle { * specified, read starts from `offset` to the end of file * @param offset File offset * @param task_size Size of each task in bytes + * @param thread_pool Thread pool to use for parallel execution. Defaults to the global default + * thread pool. * @return Future that on completion returns the size of bytes that were successfully read. 
* * @exception std::out_of_range if the read region specified by `offset` and `size` is @@ -174,7 +177,8 @@ class MmapHandle { std::future pread(void* buf, std::optional size = std::nullopt, std::size_t offset = 0, - std::size_t task_size = defaults::task_size()); + std::size_t task_size = defaults::task_size(), + ThreadPool* thread_pool = &defaults::thread_pool()); }; } // namespace kvikio diff --git a/cpp/include/kvikio/remote_handle.hpp b/cpp/include/kvikio/remote_handle.hpp index 416e374291..2f41a728d1 100644 --- a/cpp/include/kvikio/remote_handle.hpp +++ b/cpp/include/kvikio/remote_handle.hpp @@ -13,6 +13,7 @@ #include #include +#include #include struct curl_slist; @@ -452,12 +453,15 @@ class RemoteHandle { * @param size Number of bytes to read. * @param file_offset File offset in bytes. * @param task_size Size of each task in bytes. + * @param thread_pool Thread pool to use for parallel execution. Defaults to the global default + * thread pool. * @return Future that on completion returns the size of bytes read, which is always `size`. */ std::future pread(void* buf, std::size_t size, std::size_t file_offset = 0, - std::size_t task_size = defaults::task_size()); + std::size_t task_size = defaults::task_size(), + ThreadPool* thread_pool = &defaults::thread_pool()); }; } // namespace kvikio diff --git a/cpp/include/kvikio/threadpool_roundrobin.hpp b/cpp/include/kvikio/threadpool_roundrobin.hpp new file mode 100644 index 0000000000..deb47c9e1b --- /dev/null +++ b/cpp/include/kvikio/threadpool_roundrobin.hpp @@ -0,0 +1,327 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +/** + * @file + * @brief A simple, header-only thread pool that uses per-thread task queues. 
Synchronization only + * exists between the pairs of the main thread and each worker thread, but not among the worker + * threads themselves. Inspired by the BS threadpool that KvikIO has been using. + */ + +namespace kvikio { +/** + * @brief Utility class for the calling thread. + */ +class ThisThread { + public: + /** + * @brief Check if the calling thread is from RoundRobinThreadPool. + * + * @return Boolean answer. + */ + static bool is_from_pool() { return get_thread_idx().has_value(); } + + /** + * @brief Get the index of the calling thread. + * + * If the calling thread is not from RoundRobinThreadPool, return std::nullopt. Otherwise, return + * the thread index ranging from 0 to (N-1) where N is the thread count. + * + * @return Index of the calling thread. + */ + static std::optional get_thread_idx() { return this_thread_idx; } + + private: + friend class RoundRobinThreadPool; + + /** + * @brief Set the index of the calling thread. + * + * @param thread_idx Index of the calling thread. + */ + static void set_thread_idx(std::size_t thread_idx) { this_thread_idx = thread_idx; } + + inline static thread_local std::optional this_thread_idx{std::nullopt}; +}; + +/** + * @brief Per-worker state used by RoundRobinThreadPool. + * + * Bundles the worker's thread with its task queue and the synchronization primitives for that + * queue. `task_mutex` guards `task_queue`, `tasks_in_flight` and `should_stop`. `task_available_cv` + * is notified when a task is enqueued or a stop is requested; `task_done_cv` is notified by the + * worker when it has no queued or in-flight tasks. `tasks_in_flight` is incremented on submission + * and decremented after a task has run. + */ +struct Worker { + std::thread thread; + std::condition_variable task_available_cv; + std::condition_variable task_done_cv; + std::mutex task_mutex; + std::queue task_queue; + std::size_t tasks_in_flight{0}; + bool should_stop{false}; +}; + +/** + * @brief A simple thread pool that uses per-thread task queues. + * + * Each worker thread has its own task queue, mutex and condition variable. The per-thread + * synchronization primitives (mutex and condition variable) are shared with the main thread. Tasks + * are submitted to the worker threads in a round-robin fashion, unless the target thread index is + * specified by the user. 
+ * + * Example: + * @code{.cpp} + * // Create a thread pool with 4 threads, and pass an optional callable with which to initialize + * // each worker thread. + * kvikio::RoundRobinThreadPool thread_pool{4, [] { + * // Initialize worker thread + * }}; + * + * // Submit the task to the thread pool. The worker thread is selected automatically in a + * // round-robin fashion. + * auto fut = thread_pool.submit_task([] { + * // Task logic + * }); + * + * // Submit the task to a specific thread (here, the worker with index 2). + * auto fut2 = thread_pool.submit_task_to_thread([] { + * // Task logic + * }, 2); + * + * // Wait until the result is ready. + * auto result = fut.get(); + * @endcode + */ +class RoundRobinThreadPool { + public: + /** + * @brief Constructor. Create a thread pool. + * + * @tparam F Type of the user-defined worker thread initialization. + * @param num_threads Number of threads. + * @param worker_thread_init_func User-defined worker thread initialization. + */ + template + RoundRobinThreadPool(unsigned int num_threads, F&& worker_thread_init_func) + : _num_threads{num_threads}, _worker_thread_init_func{std::forward(worker_thread_init_func)} + { + create_threads(); + } + + /** + * @brief Constructor, without user-defined worker thread initialization. + * + * @param num_threads Number of threads. + */ + RoundRobinThreadPool(unsigned int num_threads) + : RoundRobinThreadPool(num_threads, detail::FunctionWrapper{}) + { + } + + /** + * @brief Destructor. Wait until all worker threads complete their tasks, then join the threads. + */ + ~RoundRobinThreadPool() + { + wait(); + destroy_threads(); + } + + /** + * @brief Get the number of threads in the thread pool. Returns the same value as num_threads(). + * + * @return Thread count. + */ + unsigned int get_thread_count() { return _num_threads; } + + /** + * @brief Wait until all worker threads complete their tasks. Then join the threads, and + * reinitialize the thread pool with new threads. + * + * @tparam F Type of the user-defined worker thread initialization. + * @param num_threads Number of threads. 
+ * @param worker_thread_init_func User-defined worker thread initialization. + */ + template + void reset(unsigned int num_threads, F&& worker_thread_init_func) + { + wait(); + destroy_threads(); + + _num_threads = num_threads; + _worker_thread_init_func = std::forward(worker_thread_init_func); + create_threads(); + } + + /** + * @brief Overload of reset(), without user-defined worker thread initialization. + * + * @param num_threads Number of threads. + */ + void reset(unsigned int num_threads) { reset(num_threads, detail::FunctionWrapper{}); } + + /** + * @brief Block the calling thread until all worker threads complete their tasks. + * + * Visits the workers in index order and waits on each worker's `task_done_cv` until that worker's + * queue is empty and its `tasks_in_flight` count is zero. + */ + void wait() + { + for (unsigned int thread_idx = 0; thread_idx < _num_threads; ++thread_idx) { + auto& worker = _workers[thread_idx]; + std::unique_lock lock(worker.task_mutex); + worker.task_done_cv.wait( + lock, [&] { return worker.task_queue.empty() && worker.tasks_in_flight == 0; }); + } + } + + /** + * @brief Get the number of threads from the thread pool. + * + * @return Thread count. + */ + unsigned int num_threads() const { return _num_threads; } + + /** + * @brief Submit the task to the thread pool for execution. The worker thread is selected + * automatically in a round-robin fashion. + * + * @note The target worker index is the submission counter modulo the thread count, so the pool + * must have at least one thread. NOTE(review): no guard against a zero thread count is visible + * here - confirm that callers never construct or reset the pool with 0 threads. + * + * @tparam F Type of the task callable. + * @tparam R Return type of the task callable. + * @param task Task callable. The task can either be copyable or move-only. + * @return An std::future object. R can be void or other types. + */ + template >> + [[nodiscard]] std::future submit_task(F&& task) + { + // The call index is atomically incremented on each submit_task call, and will wrap around once + // it reaches the maximum value the integer type `std::size_t` can hold (this overflow + // behavior is well-defined in C++). 
+ auto tid = + std::atomic_fetch_add_explicit(&_task_submission_counter, 1, std::memory_order_relaxed); + tid %= _num_threads; + + return submit_task_to_thread(std::forward(task), tid); + } + + /** + * @brief Submit the task to a specific thread for execution. + * + * The task is wrapped so that its return value, or any exception it throws, is delivered through + * the returned future. + * + * @note `thread_idx` indexes the worker array directly without a range check; it must be smaller + * than the thread count. + * + * @tparam F Type of the task callable. + * @tparam R Return type of the task callable. + * @param task Task callable. The task can either be copyable or move-only. + * @param thread_idx Index of the thread to which the task is submitted. + * @return An std::future object. R can be void or other types. + */ + template >> + [[nodiscard]] std::future submit_task_to_thread(F&& task, std::size_t thread_idx) + { + auto& worker = _workers[thread_idx]; + + std::promise p; + auto fut = p.get_future(); + + { + std::lock_guard lock(worker.task_mutex); + + worker.task_queue.emplace([task = std::forward(task), p = std::move(p)]() mutable { + try { + if constexpr (std::is_same_v) { + task(); + p.set_value(); + } else { + p.set_value(task()); + } + } catch (...) { + p.set_exception(std::current_exception()); + } + }); + + ++worker.tasks_in_flight; + } + + worker.task_available_cv.notify_one(); + return fut; + } + + private: + /** + * @brief Worker thread loop. + * + * Records `thread_idx` via ThisThread, invokes the worker initialization callable if one was + * provided, then repeatedly takes the front task of this worker's queue and runs it outside the + * lock. The loop exits as soon as `should_stop` is observed, without running any tasks that may + * still be queued. + * + * @param thread_idx Worker thread index. 
+ */ + void run_worker(std::size_t thread_idx) + { + ThisThread::set_thread_idx(thread_idx); + + auto& worker = _workers[thread_idx]; + + if (_worker_thread_init_func) { std::invoke(_worker_thread_init_func); } + + while (true) { + std::unique_lock lock(worker.task_mutex); + + if (worker.task_queue.empty() && worker.tasks_in_flight == 0) { + worker.task_done_cv.notify_all(); + } + + worker.task_available_cv.wait( + lock, [&] { return !worker.task_queue.empty() || worker.should_stop; }); + + if (worker.should_stop) { break; } + + auto task = std::move(worker.task_queue.front()); + worker.task_queue.pop(); + lock.unlock(); + + task(); + + { + std::lock_guard done_lock(worker.task_mutex); + --worker.tasks_in_flight; + } + } + } + + /** + * @brief Create worker threads. + */ + void create_threads() + { + _workers = std::make_unique(_num_threads); + for (unsigned int thread_idx = 0; thread_idx < _num_threads; ++thread_idx) { + _workers[thread_idx].thread = std::thread([this, thread_idx] { run_worker(thread_idx); }); + } + } + + /** + * @brief Notify each worker thread of the intention to stop and join the threads. Pre-condition: + * Each worker thread has finished all the tasks in its task queue. 
+ */ + void destroy_threads() + { + for (unsigned int thread_idx = 0; thread_idx < _num_threads; ++thread_idx) { + auto& worker = _workers[thread_idx]; + + { + std::lock_guard lock(worker.task_mutex); + worker.should_stop = true; + } + + worker.task_available_cv.notify_one(); + + worker.thread.join(); + } + } + + unsigned int _num_threads{}; + detail::FunctionWrapper _worker_thread_init_func; + std::unique_ptr _workers; + std::atomic_size_t _task_submission_counter{0}; +}; + +} // namespace kvikio diff --git a/cpp/include/kvikio/threadpool_wrapper.hpp b/cpp/include/kvikio/threadpool_wrapper.hpp index 0644b8c9ca..d0e8d4b286 100644 --- a/cpp/include/kvikio/threadpool_wrapper.hpp +++ b/cpp/include/kvikio/threadpool_wrapper.hpp @@ -9,24 +9,9 @@ namespace kvikio { -template -class thread_pool_wrapper : public pool_type { - public: - /** - * @brief Construct a new thread pool wrapper. - * - * @param nthreads The number of threads to use. - */ - thread_pool_wrapper(unsigned int nthreads) : pool_type{nthreads} {} - - /** - * @brief Reset the number of threads in the thread pool. - * - * @param nthreads The number of threads to use. - */ - void reset(unsigned int nthreads) { pool_type::reset(nthreads); } -}; - -using BS_thread_pool = thread_pool_wrapper; +/** + * @brief Thread pool type used for parallel I/O operations. 
+ */ +using ThreadPool = BS::thread_pool; } // namespace kvikio diff --git a/cpp/src/defaults.cpp b/cpp/src/defaults.cpp index f827ef6cf5..1bbc151b86 100644 --- a/cpp/src/defaults.cpp +++ b/cpp/src/defaults.cpp @@ -172,7 +172,7 @@ bool defaults::is_compat_mode_preferred(CompatMode compat_mode) noexcept bool defaults::is_compat_mode_preferred() { return is_compat_mode_preferred(compat_mode()); } -BS_thread_pool& defaults::thread_pool() { return instance()->_thread_pool; } +ThreadPool& defaults::thread_pool() { return instance()->_thread_pool; } unsigned int defaults::thread_pool_nthreads() { return thread_pool().get_thread_count(); } diff --git a/cpp/src/file_handle.cpp b/cpp/src/file_handle.cpp index 30f1cf335a..abec24fa79 100644 --- a/cpp/src/file_handle.cpp +++ b/cpp/src/file_handle.cpp @@ -19,6 +19,7 @@ #include #include #include +#include namespace kvikio { @@ -148,8 +149,10 @@ std::future FileHandle::pread(void* buf, std::size_t file_offset, std::size_t task_size, std::size_t gds_threshold, - bool sync_default_stream) + bool sync_default_stream, + ThreadPool* thread_pool) { + KVIKIO_EXPECT(thread_pool != nullptr, "The thread pool must not be nullptr"); auto& [nvtx_color, call_idx] = detail::get_next_color_and_call_idx(); KVIKIO_NVTX_FUNC_RANGE(size, nvtx_color); if (is_host_memory(buf)) { @@ -162,7 +165,7 @@ std::future FileHandle::pread(void* buf, _file_direct_off.fd(), buf, size, file_offset, _file_direct_on.fd()); }; - return parallel_io(op, buf, size, file_offset, task_size, 0, call_idx, nvtx_color); + return parallel_io(op, buf, size, file_offset, task_size, 0, thread_pool, call_idx, nvtx_color); } CUcontext ctx = get_context_from_pointer(buf); @@ -192,8 +195,15 @@ std::future FileHandle::pread(void* buf, return read(devPtr_base, size, file_offset, devPtr_offset, /* sync_default_stream = */ false); }; auto [devPtr_base, base_size, devPtr_offset] = get_alloc_info(buf, &ctx); - return parallel_io( - task, devPtr_base, size, file_offset, task_size, 
devPtr_offset, call_idx, nvtx_color); + return parallel_io(task, + devPtr_base, + size, + file_offset, + task_size, + devPtr_offset, + thread_pool, + call_idx, + nvtx_color); } std::future FileHandle::pwrite(void const* buf, @@ -201,8 +211,10 @@ std::future FileHandle::pwrite(void const* buf, std::size_t file_offset, std::size_t task_size, std::size_t gds_threshold, - bool sync_default_stream) + bool sync_default_stream, + ThreadPool* thread_pool) { + KVIKIO_EXPECT(thread_pool != nullptr, "The thread pool must not be nullptr"); auto& [nvtx_color, call_idx] = detail::get_next_color_and_call_idx(); KVIKIO_NVTX_FUNC_RANGE(size, nvtx_color); if (is_host_memory(buf)) { @@ -215,7 +227,7 @@ std::future FileHandle::pwrite(void const* buf, _file_direct_off.fd(), buf, size, file_offset, _file_direct_on.fd()); }; - return parallel_io(op, buf, size, file_offset, task_size, 0, call_idx, nvtx_color); + return parallel_io(op, buf, size, file_offset, task_size, 0, thread_pool, call_idx, nvtx_color); } CUcontext ctx = get_context_from_pointer(buf); @@ -245,8 +257,15 @@ std::future FileHandle::pwrite(void const* buf, return write(devPtr_base, size, file_offset, devPtr_offset, /* sync_default_stream = */ false); }; auto [devPtr_base, base_size, devPtr_offset] = get_alloc_info(buf, &ctx); - return parallel_io( - op, devPtr_base, size, file_offset, task_size, devPtr_offset, call_idx, nvtx_color); + return parallel_io(op, + devPtr_base, + size, + file_offset, + task_size, + devPtr_offset, + thread_pool, + call_idx, + nvtx_color); } void FileHandle::read_async(void* devPtr_base, diff --git a/cpp/src/mmap.cpp b/cpp/src/mmap.cpp index a720fa8929..ff579cfa4e 100644 --- a/cpp/src/mmap.cpp +++ b/cpp/src/mmap.cpp @@ -415,10 +415,12 @@ std::size_t MmapHandle::read(void* buf, std::optional size, std::si std::future MmapHandle::pread(void* buf, std::optional size, std::size_t offset, - std::size_t task_size) + std::size_t task_size, + ThreadPool* thread_pool) { KVIKIO_EXPECT(task_size <= 
defaults::bounce_buffer_size(), "bounce buffer size cannot be less than task size."); + KVIKIO_EXPECT(thread_pool != nullptr, "The thread pool must not be nullptr"); auto actual_size = validate_and_adjust_read_args(size, offset); if (actual_size == 0) { return make_ready_future(actual_size); } @@ -448,6 +450,7 @@ std::future MmapHandle::pread(void* buf, offset, task_size, 0, // dst buffer offset initial value + thread_pool, call_idx, nvtx_color); } diff --git a/cpp/src/remote_handle.cpp b/cpp/src/remote_handle.cpp index 6004515b76..7c917c9a0b 100644 --- a/cpp/src/remote_handle.cpp +++ b/cpp/src/remote_handle.cpp @@ -809,8 +809,10 @@ std::size_t RemoteHandle::read(void* buf, std::size_t size, std::size_t file_off std::future RemoteHandle::pread(void* buf, std::size_t size, std::size_t file_offset, - std::size_t task_size) + std::size_t task_size, + ThreadPool* thread_pool) { + KVIKIO_EXPECT(thread_pool != nullptr, "The thread pool must not be nullptr"); auto& [nvtx_color, call_idx] = detail::get_next_color_and_call_idx(); KVIKIO_NVTX_FUNC_RANGE(size); auto task = [this](void* devPtr_base, @@ -819,7 +821,7 @@ std::future RemoteHandle::pread(void* buf, std::size_t devPtr_offset) -> std::size_t { return read(static_cast(devPtr_base) + devPtr_offset, size, file_offset); }; - return parallel_io(task, buf, size, file_offset, task_size, 0, call_idx, nvtx_color); + return parallel_io(task, buf, size, file_offset, task_size, 0, thread_pool, call_idx, nvtx_color); } } // namespace kvikio diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 2effd6a559..2a289af360 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -69,6 +69,8 @@ kvikio_add_test(NAME ERROR_TEST SOURCES test_error.cpp) kvikio_add_test(NAME MMAP_TEST SOURCES test_mmap.cpp) +kvikio_add_test(NAME ROUNDROBIN_THREADPOOL_TEST SOURCES test_roundrobin_threadpool.cpp) + if(KvikIO_REMOTE_SUPPORT) kvikio_add_test(NAME REMOTE_HANDLE_TEST SOURCES test_remote_handle.cpp utils/env.cpp) 
kvikio_add_test(NAME HDFS_TEST SOURCES test_hdfs.cpp utils/hdfs_helper.cpp) diff --git a/cpp/tests/test_roundrobin_threadpool.cpp b/cpp/tests/test_roundrobin_threadpool.cpp new file mode 100644 index 0000000000..668007e989 --- /dev/null +++ b/cpp/tests/test_roundrobin_threadpool.cpp @@ -0,0 +1,11 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +#include + +/** + * @brief Placeholder test case: the body is empty, so no RoundRobinThreadPool behavior is + * exercised yet. + * + * TODO: add cases covering task submission, round-robin distribution, wait() and reset(). + */ +TEST(RoundRobinThreadPoolTest, basics) {}