diff --git a/cpp/include/kvikio/defaults.hpp b/cpp/include/kvikio/defaults.hpp
index 722986c362..aff8aa5ea8 100644
--- a/cpp/include/kvikio/defaults.hpp
+++ b/cpp/include/kvikio/defaults.hpp
@@ -111,7 +111,7 @@ std::tuple getenv_or(
  */
 class defaults {
  private:
-  BS_thread_pool _thread_pool{get_num_threads_from_env()};
+  ThreadPool _thread_pool{get_num_threads_from_env()};
   CompatMode _compat_mode;
   std::size_t _task_size;
   std::size_t _gds_threshold;
@@ -212,7 +212,7 @@ class defaults {
    *
    * @return The default thread pool instance.
    */
-  [[nodiscard]] static BS_thread_pool& thread_pool();
+  [[nodiscard]] static ThreadPool& thread_pool();
 
   /**
    * @brief Get the number of threads in the default thread pool.
diff --git a/cpp/include/kvikio/detail/parallel_operation.hpp b/cpp/include/kvikio/detail/parallel_operation.hpp
index a4489da8e5..1d3c43d287 100644
--- a/cpp/include/kvikio/detail/parallel_operation.hpp
+++ b/cpp/include/kvikio/detail/parallel_operation.hpp
@@ -17,6 +17,7 @@
 #include
 #include
 #include
+#include <kvikio/threadpool_wrapper.hpp>
 #include
 
 namespace kvikio {
@@ -75,6 +76,7 @@ std::future<std::size_t> submit_task(F op,
                                      std::size_t size,
                                      std::size_t file_offset,
                                      std::size_t devPtr_offset,
+                                     ThreadPool* thread_pool = &defaults::thread_pool(),
                                      std::uint64_t nvtx_payload = 0ull,
                                      nvtx_color_type nvtx_color = NvtxManager::default_color())
 {
@@ -85,7 +87,7 @@ std::future<std::size_t> submit_task(F op,
                                       decltype(file_offset),
                                       decltype(devPtr_offset)>);
 
-  return defaults::thread_pool().submit_task([=] {
+  return thread_pool->submit_task([=] {
     KVIKIO_NVTX_SCOPED_RANGE("task", nvtx_payload, nvtx_color);
     return op(buf, size, file_offset, devPtr_offset);
   });
@@ -101,12 +103,13 @@
 template <typename F>
 std::future<std::size_t> submit_move_only_task(
   F op_move_only,
+  ThreadPool* thread_pool = &defaults::thread_pool(),
   std::uint64_t nvtx_payload = 0ull,
   nvtx_color_type nvtx_color = NvtxManager::default_color())
 {
   static_assert(std::is_invocable_r_v<std::size_t, F>);
   auto op_copyable = make_copyable_lambda(std::move(op_move_only));
-  return defaults::thread_pool().submit_task([=] {
+  return thread_pool->submit_task([=] {
     KVIKIO_NVTX_SCOPED_RANGE("task", nvtx_payload, nvtx_color);
     return op_copyable();
   });
@@ -124,6 +127,10 @@ std::future<std::size_t> submit_move_only_task(
  * @param size Number of bytes to read or write.
  * @param file_offset Byte offset to the start of the file.
  * @param task_size Size of each task in bytes.
+ * @param devPtr_offset Offset relative to the `devPtr_base` pointer. This parameter should be used
+ * only with registered buffers.
+ * @param thread_pool Thread pool to use for parallel execution. Defaults to the global default
+ * thread pool.
 * @return A future to be used later to check if the operation has finished its execution.
 */
 template <typename F, typename T>
@@ -133,10 +140,12 @@ std::future<std::size_t> parallel_io(F op,
                                      std::size_t file_offset,
                                      std::size_t task_size,
                                      std::size_t devPtr_offset,
+                                     ThreadPool* thread_pool = &defaults::thread_pool(),
                                      std::uint64_t call_idx = 0,
                                      nvtx_color_type nvtx_color = NvtxManager::default_color())
 {
   KVIKIO_EXPECT(task_size > 0, "`task_size` must be positive", std::invalid_argument);
+  KVIKIO_EXPECT(thread_pool != nullptr, "The thread pool must not be nullptr");
   static_assert(std::is_invocable_r_v<std::size_t,
@@ … @@ std::future<std::size_t> parallel_io(F op,
   // Single-task guard
   if (task_size >= size || get_page_size() >= size) {
-    return detail::submit_task(op, buf, size, file_offset, devPtr_offset, call_idx, nvtx_color);
+    return detail::submit_task(
+      op, buf, size, file_offset, devPtr_offset, thread_pool, call_idx, nvtx_color);
   }
 
   std::vector<std::future<std::size_t>> tasks;
@@ -154,8 +164,8 @@ std::future<std::size_t> parallel_io(F op,
 
   // 1) Submit all tasks but the last one. These are all `task_size` sized tasks.
   while (size > task_size) {
-    tasks.push_back(
-      detail::submit_task(op, buf, task_size, file_offset, devPtr_offset, call_idx, nvtx_color));
+    tasks.push_back(detail::submit_task(
+      op, buf, task_size, file_offset, devPtr_offset, thread_pool, call_idx, nvtx_color));
     file_offset += task_size;
     devPtr_offset += task_size;
     size -= task_size;
@@ -170,7 +180,7 @@ std::future<std::size_t> parallel_io(F op,
     }
     return ret;
   };
-  return detail::submit_move_only_task(std::move(last_task), call_idx, nvtx_color);
+  return detail::submit_move_only_task(std::move(last_task), thread_pool, call_idx, nvtx_color);
 }
 
 }  // namespace kvikio
diff --git a/cpp/include/kvikio/file_handle.hpp b/cpp/include/kvikio/file_handle.hpp
index fd24e7623b..613a359d1b 100644
--- a/cpp/include/kvikio/file_handle.hpp
+++ b/cpp/include/kvikio/file_handle.hpp
@@ -20,6 +20,7 @@
 #include
 #include
 #include
+#include <kvikio/threadpool_wrapper.hpp>
 #include
 
 namespace kvikio {
@@ -228,17 +229,22 @@ class FileHandle {
    * in the null stream. When in KvikIO's compatibility mode or when accessing host memory, the
    * operation is always default stream ordered like the rest of the non-async CUDA API. In this
    * case, the value of `sync_default_stream` is ignored.
+   * @param thread_pool Thread pool to use for parallel execution. Defaults to the global default
+   * thread pool. The caller is responsible for ensuring that the thread pool remains valid until
+   * the returned future is consumed (i.e., until `get()` or `wait()` is called on it).
    * @return Future that on completion returns the size of bytes that were successfully read.
    *
-   * @note The `std::future` object's `wait()` or `get()` should not be called after the lifetime of
-   * the FileHandle object ends. Otherwise, the behavior is undefined.
+   * @note The returned `std::future` object must not outlive either the FileHandle or the thread
+   * pool. Calling `wait()` or `get()` on the future after the FileHandle or thread pool has been
+   * destroyed results in undefined behavior.
    */
   std::future<std::size_t> pread(void* buf,
                                  std::size_t size,
                                  std::size_t file_offset = 0,
                                  std::size_t task_size = defaults::task_size(),
                                  std::size_t gds_threshold = defaults::gds_threshold(),
-                                 bool sync_default_stream = true);
+                                 bool sync_default_stream = true,
+                                 ThreadPool* thread_pool = &defaults::thread_pool());
 
   /**
    * @brief Writes specified bytes from device or host memory into the file in parallel.
@@ -265,17 +271,22 @@ class FileHandle {
    * in the null stream. When in KvikIO's compatibility mode or when accessing host memory, the
   * operation is always default stream ordered like the rest of the non-async CUDA API. In this
   * case, the value of `sync_default_stream` is ignored.
+   * @param thread_pool Thread pool to use for parallel execution. Defaults to the global default
+   * thread pool. The caller is responsible for ensuring that the thread pool remains valid until
+   * the returned future is consumed (i.e., until `get()` or `wait()` is called on it).
    * @return Future that on completion returns the size of bytes that were successfully written.
    *
-   * @note The `std::future` object's `wait()` or `get()` should not be called after the lifetime of
-   * the FileHandle object ends. Otherwise, the behavior is undefined.
+   * @note The returned `std::future` object must not outlive either the FileHandle or the thread
+   * pool. Calling `wait()` or `get()` on the future after the FileHandle or thread pool has been
+   * destroyed results in undefined behavior.
    */
   std::future<std::size_t> pwrite(void const* buf,
                                   std::size_t size,
                                   std::size_t file_offset = 0,
                                   std::size_t task_size = defaults::task_size(),
                                   std::size_t gds_threshold = defaults::gds_threshold(),
-                                  bool sync_default_stream = true);
+                                  bool sync_default_stream = true,
+                                  ThreadPool* thread_pool = &defaults::thread_pool());
 
   /**
    * @brief Reads specified bytes from the file into the device memory asynchronously.
diff --git a/cpp/include/kvikio/mmap.hpp b/cpp/include/kvikio/mmap.hpp
index fe8b71cbf4..da6f596a11 100644
--- a/cpp/include/kvikio/mmap.hpp
+++ b/cpp/include/kvikio/mmap.hpp
@@ -9,6 +9,7 @@
 #include
 #include
+#include <kvikio/threadpool_wrapper.hpp>
 #include
 
 namespace kvikio {
@@ -162,19 +163,24 @@ class MmapHandle {
    * specified, read starts from `offset` to the end of file
    * @param offset File offset
    * @param task_size Size of each task in bytes
+   * @param thread_pool Thread pool to use for parallel execution. Defaults to the global default
+   * thread pool. The caller is responsible for ensuring that the thread pool remains valid until
+   * the returned future is consumed (i.e., until `get()` or `wait()` is called on it).
    * @return Future that on completion returns the size of bytes that were successfully read.
    *
    * @exception std::out_of_range if the read region specified by `offset` and `size` is
    * outside the initial region specified when the mapping handle was constructed
    * @exception std::runtime_error if the mapping handle is closed
    *
-   * @note The `std::future` object's `wait()` or `get()` should not be called after the lifetime of
-   * the MmapHandle object ends. Otherwise, the behavior is undefined.
+   * @note The returned `std::future` object must not outlive either the MmapHandle or the thread
+   * pool. Calling `wait()` or `get()` on the future after the MmapHandle or thread pool has been
+   * destroyed results in undefined behavior.
    */
   std::future<std::size_t> pread(void* buf,
                                  std::optional<std::size_t> size = std::nullopt,
                                  std::size_t offset = 0,
-                                 std::size_t task_size = defaults::task_size());
+                                 std::size_t task_size = defaults::task_size(),
+                                 ThreadPool* thread_pool = &defaults::thread_pool());
 };
 
 }  // namespace kvikio
diff --git a/cpp/include/kvikio/remote_handle.hpp b/cpp/include/kvikio/remote_handle.hpp
index 416e374291..0b0808c45e 100644
--- a/cpp/include/kvikio/remote_handle.hpp
+++ b/cpp/include/kvikio/remote_handle.hpp
@@ -13,6 +13,7 @@
 #include
 #include
+#include <kvikio/threadpool_wrapper.hpp>
 #include
 
 struct curl_slist;
@@ -452,12 +453,20 @@ class RemoteHandle {
    * @param size Number of bytes to read.
    * @param file_offset File offset in bytes.
    * @param task_size Size of each task in bytes.
+   * @param thread_pool Thread pool to use for parallel execution. Defaults to the global default
+   * thread pool. The caller is responsible for ensuring that the thread pool remains valid until
+   * the returned future is consumed (i.e., until `get()` or `wait()` is called on it).
    * @return Future that on completion returns the size of bytes read, which is always `size`.
+   *
+   * @note The returned `std::future` object must not outlive either the RemoteHandle or the thread
+   * pool. Calling `wait()` or `get()` on the future after the RemoteHandle or thread pool has been
+   * destroyed results in undefined behavior.
    */
   std::future<std::size_t> pread(void* buf,
                                  std::size_t size,
                                  std::size_t file_offset = 0,
-                                 std::size_t task_size = defaults::task_size());
+                                 std::size_t task_size = defaults::task_size(),
+                                 ThreadPool* thread_pool = &defaults::thread_pool());
 };
 
 }  // namespace kvikio
diff --git a/cpp/include/kvikio/threadpool_wrapper.hpp b/cpp/include/kvikio/threadpool_wrapper.hpp
index 0644b8c9ca..d0e8d4b286 100644
--- a/cpp/include/kvikio/threadpool_wrapper.hpp
+++ b/cpp/include/kvikio/threadpool_wrapper.hpp
@@ -9,24 +9,9 @@
 namespace kvikio {
 
-template <typename pool_type>
-class thread_pool_wrapper : public pool_type {
- public:
-  /**
-   * @brief Construct a new thread pool wrapper.
-   *
-   * @param nthreads The number of threads to use.
-   */
-  thread_pool_wrapper(unsigned int nthreads) : pool_type{nthreads} {}
-
-  /**
-   * @brief Reset the number of threads in the thread pool.
-   *
-   * @param nthreads The number of threads to use.
-   */
-  void reset(unsigned int nthreads) { pool_type::reset(nthreads); }
-};
-
-using BS_thread_pool = thread_pool_wrapper<BS::thread_pool>;
+/**
+ * @brief Thread pool type used for parallel I/O operations.
+ */
+using ThreadPool = BS::thread_pool;
 
 }  // namespace kvikio
diff --git a/cpp/src/defaults.cpp b/cpp/src/defaults.cpp
index f827ef6cf5..1bbc151b86 100644
--- a/cpp/src/defaults.cpp
+++ b/cpp/src/defaults.cpp
@@ -172,7 +172,7 @@ bool defaults::is_compat_mode_preferred(CompatMode compat_mode) noexcept
 
 bool defaults::is_compat_mode_preferred() { return is_compat_mode_preferred(compat_mode()); }
 
-BS_thread_pool& defaults::thread_pool() { return instance()->_thread_pool; }
+ThreadPool& defaults::thread_pool() { return instance()->_thread_pool; }
 
 unsigned int defaults::thread_pool_nthreads() { return thread_pool().get_thread_count(); }
 
diff --git a/cpp/src/file_handle.cpp b/cpp/src/file_handle.cpp
index 30f1cf335a..abec24fa79 100644
--- a/cpp/src/file_handle.cpp
+++ b/cpp/src/file_handle.cpp
@@ -19,6 +19,7 @@
 #include
 #include
 #include
+#include <kvikio/threadpool_wrapper.hpp>
 
 namespace kvikio {
 
@@ -148,8 +149,10 @@ std::future<std::size_t> FileHandle::pread(void* buf,
                                            std::size_t file_offset,
                                            std::size_t task_size,
                                            std::size_t gds_threshold,
-                                           bool sync_default_stream)
+                                           bool sync_default_stream,
+                                           ThreadPool* thread_pool)
 {
+  KVIKIO_EXPECT(thread_pool != nullptr, "The thread pool must not be nullptr");
   auto& [nvtx_color, call_idx] = detail::get_next_color_and_call_idx();
   KVIKIO_NVTX_FUNC_RANGE(size, nvtx_color);
   if (is_host_memory(buf)) {
@@ -162,7 +165,7 @@ std::future<std::size_t> FileHandle::pread(void* buf,
         _file_direct_off.fd(), buf, size, file_offset, _file_direct_on.fd());
     };
 
-    return parallel_io(op, buf, size, file_offset, task_size, 0, call_idx, nvtx_color);
+    return parallel_io(op, buf, size, file_offset, task_size, 0, thread_pool, call_idx, nvtx_color);
   }
 
   CUcontext ctx = get_context_from_pointer(buf);
@@ -192,8 +195,15 @@ std::future<std::size_t> FileHandle::pread(void* buf,
     return read(devPtr_base, size, file_offset, devPtr_offset, /* sync_default_stream = */ false);
   };
   auto [devPtr_base, base_size, devPtr_offset] = get_alloc_info(buf, &ctx);
-  return parallel_io(
-    task, devPtr_base, size, file_offset, task_size, devPtr_offset, call_idx, nvtx_color);
+  return parallel_io(task,
+                     devPtr_base,
+                     size,
+                     file_offset,
+                     task_size,
+                     devPtr_offset,
+                     thread_pool,
+                     call_idx,
+                     nvtx_color);
 }
 
 std::future<std::size_t> FileHandle::pwrite(void const* buf,
@@ -201,8 +211,10 @@ std::future<std::size_t> FileHandle::pwrite(void const* buf,
                                             std::size_t file_offset,
                                             std::size_t task_size,
                                             std::size_t gds_threshold,
-                                            bool sync_default_stream)
+                                            bool sync_default_stream,
+                                            ThreadPool* thread_pool)
 {
+  KVIKIO_EXPECT(thread_pool != nullptr, "The thread pool must not be nullptr");
   auto& [nvtx_color, call_idx] = detail::get_next_color_and_call_idx();
   KVIKIO_NVTX_FUNC_RANGE(size, nvtx_color);
   if (is_host_memory(buf)) {
@@ -215,7 +227,7 @@ std::future<std::size_t> FileHandle::pwrite(void const* buf,
         _file_direct_off.fd(), buf, size, file_offset, _file_direct_on.fd());
     };
 
-    return parallel_io(op, buf, size, file_offset, task_size, 0, call_idx, nvtx_color);
+    return parallel_io(op, buf, size, file_offset, task_size, 0, thread_pool, call_idx, nvtx_color);
   }
 
   CUcontext ctx = get_context_from_pointer(buf);
@@ -245,8 +257,15 @@ std::future<std::size_t> FileHandle::pwrite(void const* buf,
     return write(devPtr_base, size, file_offset, devPtr_offset, /* sync_default_stream = */ false);
   };
   auto [devPtr_base, base_size, devPtr_offset] = get_alloc_info(buf, &ctx);
-  return parallel_io(
-    op, devPtr_base, size, file_offset, task_size, devPtr_offset, call_idx, nvtx_color);
+  return parallel_io(op,
+                     devPtr_base,
+                     size,
+                     file_offset,
+                     task_size,
+                     devPtr_offset,
+                     thread_pool,
+                     call_idx,
+                     nvtx_color);
 }
 
 void FileHandle::read_async(void* devPtr_base,
diff --git a/cpp/src/mmap.cpp b/cpp/src/mmap.cpp
index a720fa8929..ff579cfa4e 100644
--- a/cpp/src/mmap.cpp
+++ b/cpp/src/mmap.cpp
@@ -415,10 +415,12 @@ std::size_t MmapHandle::read(void* buf, std::optional size, std::si
 std::future<std::size_t> MmapHandle::pread(void* buf,
                                            std::optional<std::size_t> size,
                                            std::size_t offset,
-                                           std::size_t task_size)
+                                           std::size_t task_size,
+                                           ThreadPool* thread_pool)
 {
   KVIKIO_EXPECT(task_size <= defaults::bounce_buffer_size(),
                 "bounce buffer size cannot be less than task size.");
+  KVIKIO_EXPECT(thread_pool != nullptr, "The thread pool must not be nullptr");
   auto actual_size = validate_and_adjust_read_args(size, offset);
   if (actual_size == 0) { return make_ready_future(actual_size); }
 
@@ -448,6 +450,7 @@ std::future<std::size_t> MmapHandle::pread(void* buf,
                      offset,
                      task_size,
                      0,  // dst buffer offset initial value
+                     thread_pool,
                      call_idx,
                      nvtx_color);
 }
diff --git a/cpp/src/remote_handle.cpp b/cpp/src/remote_handle.cpp
index 6004515b76..7c917c9a0b 100644
--- a/cpp/src/remote_handle.cpp
+++ b/cpp/src/remote_handle.cpp
@@ -809,8 +809,10 @@ std::size_t RemoteHandle::read(void* buf, std::size_t size, std::size_t file_off
 std::future<std::size_t> RemoteHandle::pread(void* buf,
                                              std::size_t size,
                                              std::size_t file_offset,
-                                             std::size_t task_size)
+                                             std::size_t task_size,
+                                             ThreadPool* thread_pool)
 {
+  KVIKIO_EXPECT(thread_pool != nullptr, "The thread pool must not be nullptr");
   auto& [nvtx_color, call_idx] = detail::get_next_color_and_call_idx();
   KVIKIO_NVTX_FUNC_RANGE(size);
   auto task = [this](void* devPtr_base,
@@ -819,7 +821,7 @@ std::future<std::size_t> RemoteHandle::pread(void* buf,
                      std::size_t devPtr_offset) -> std::size_t {
     return read(static_cast(devPtr_base) + devPtr_offset, size, file_offset);
   };
-  return parallel_io(task, buf, size, file_offset, task_size, 0, call_idx, nvtx_color);
+  return parallel_io(task, buf, size, file_offset, task_size, 0, thread_pool, call_idx, nvtx_color);
 }
 
 }  // namespace kvikio
diff --git a/cpp/tests/test_basic_io.cpp b/cpp/tests/test_basic_io.cpp
index aeda7051d0..113d677e95 100644
--- a/cpp/tests/test_basic_io.cpp
+++ b/cpp/tests/test_basic_io.cpp
@@ -7,12 +7,14 @@
 #include
 #include
 #include
+#include <memory>
 #include
 
 #include
 #include
 #include
 #include
+#include <kvikio/threadpool_wrapper.hpp>
 #include
 
 #include "utils/env.hpp"
@@ -27,7 +29,7 @@ class BasicIOTest : public testing::Test {
     TempDir tmp_dir{false};
     _filepath = tmp_dir.path() / "test";
 
-    _dev_a = std::move(DevBuffer::arange(100));
+    _dev_a = std::move(DevBuffer::arange(1024ULL * 1024ULL + 124ULL));
     _dev_b = std::move(DevBuffer::zero_like(_dev_a));
   }
 
@@ -98,6 +100,55 @@ TEST_F(BasicIOTest, write_read_async)
   CUDA_DRIVER_TRY(kvikio::cudaAPI::instance().StreamDestroy(stream));
 }
 
+TEST_F(BasicIOTest, threadpool)
+{
+  auto thread_pool = std::make_unique<kvikio::ThreadPool>(4);
+
+  // Write to a file using an external thread pool
+  {
+    kvikio::FileHandle f(_filepath, "w");
+    auto fut = f.pwrite(_dev_a.ptr,
+                        _dev_a.nbytes,  // size
+                        0,              // file_offset
+                        kvikio::defaults::task_size(),
+                        kvikio::defaults::gds_threshold(),
+                        true,
+                        thread_pool.get());
+    auto nbytes_written = fut.get();
+    EXPECT_EQ(nbytes_written, _dev_a.nbytes);
+  }
+
+  // Read from the file using an external thread pool
+  {
+    std::vector<std::future<std::size_t>> futs;
+    std::vector filepaths{_filepath, _filepath};
+    std::vector<kvikio::FileHandle> file_handles;
+    std::vector<DevBuffer> dev_buffers;
+
+    for (auto const& filepath : filepaths) {
+      file_handles.emplace_back(filepath, "r");
+      dev_buffers.push_back(DevBuffer::zero_like(_dev_a));
+    }
+
+    for (std::size_t i = 0; i < file_handles.size(); ++i) {
+      auto fut = file_handles[i].pread(dev_buffers[i].ptr,
+                                       dev_buffers[i].nbytes,
+                                       0,
+                                       kvikio::defaults::task_size(),
+                                       kvikio::defaults::gds_threshold(),
+                                       true,
+                                       thread_pool.get());
+      futs.push_back(std::move(fut));
+    }
+
+    for (std::size_t i = 0; i < file_handles.size(); ++i) {
+      auto nbytes_read = futs[i].get();
+      EXPECT_EQ(nbytes_read, _dev_a.nbytes);
+      expect_equal(_dev_a, dev_buffers[i]);
+    }
+  }
+}
+
 class DirectIOTest : public testing::Test {
  public:
   using value_type = std::int64_t;
diff --git a/cpp/tests/test_mmap.cpp b/cpp/tests/test_mmap.cpp
index 9e355f4789..1d280d6594 100644
--- a/cpp/tests/test_mmap.cpp
+++ b/cpp/tests/test_mmap.cpp
@@ -360,3 +360,37 @@ TEST_F(MmapTest, cpp_move)
     do_test(mmap_handle_2);
   }
 }
+
+TEST_F(MmapTest, threadpool)
+{
+  auto thread_pool = std::make_unique<kvikio::ThreadPool>(4);
+
+  // Read from the file using an external thread pool
+  {
+    std::size_t num_elements = _file_size / sizeof(value_type);
+    std::vector<std::future<std::size_t>> futs;
+    std::vector filepaths{_filepath, _filepath};
+    std::vector<kvikio::MmapHandle> mmap_handles;
+    std::vector<kvikio::test::DevBuffer> dev_buffers;
+
+    for (auto const& filepath : filepaths) {
+      mmap_handles.emplace_back(filepath, "r");
+      dev_buffers.push_back(kvikio::test::DevBuffer::zero_like(num_elements));
+    }
+
+    for (std::size_t i = 0; i < mmap_handles.size(); ++i) {
+      auto fut = mmap_handles[i].pread(dev_buffers[i].ptr,
+                                       dev_buffers[i].nbytes,
+                                       0,
+                                       kvikio::defaults::task_size(),
+                                       thread_pool.get());
+      futs.push_back(std::move(fut));
+    }
+
+    for (std::size_t i = 0; i < mmap_handles.size(); ++i) {
+      auto nbytes_read = futs[i].get();
+      EXPECT_EQ(nbytes_read, _file_size);
+      EXPECT_EQ(_host_buf, dev_buffers[i].to_vector());
+    }
+  }
+}
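
Usage sketch, not part of the diff above: a minimal example of passing a caller-owned pool through the new trailing `thread_pool` parameter on `FileHandle::pread`/`pwrite`; the same pattern applies to `MmapHandle::pread` and `RemoteHandle::pread`. The file path, buffer size, and use of host memory are illustrative assumptions, and error handling is omitted.

#include <cstddef>
#include <future>
#include <vector>

#include <kvikio/defaults.hpp>
#include <kvikio/file_handle.hpp>
#include <kvikio/threadpool_wrapper.hpp>

int main()
{
  // Caller-owned pool, independent of kvikio::defaults::thread_pool().
  // It must outlive every future produced with it (see the @note blocks above).
  kvikio::ThreadPool pool{4};

  std::vector<char> buf(1 << 20, 'x');  // illustrative host buffer

  {
    kvikio::FileHandle f{"/tmp/example.bin", "w"};  // hypothetical path
    auto fut = f.pwrite(buf.data(),
                        buf.size(),
                        0,  // file_offset
                        kvikio::defaults::task_size(),
                        kvikio::defaults::gds_threshold(),
                        true,    // sync_default_stream
                        &pool);  // external thread pool
    fut.get();  // consume the future while `f` and `pool` are still alive
  }

  {
    kvikio::FileHandle f{"/tmp/example.bin", "r"};
    auto fut = f.pread(buf.data(),
                       buf.size(),
                       0,
                       kvikio::defaults::task_size(),
                       kvikio::defaults::gds_threshold(),
                       true,
                       &pool);
    fut.get();
  }
  return 0;
}

Because the returned futures represent work queued on `pool`, both the handle and the pool must stay alive until `get()` or `wait()` returns, which is exactly the ownership rule the new `@note` paragraphs in the headers describe.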