Merged
4 changes: 2 additions & 2 deletions cpp/include/kvikio/defaults.hpp
@@ -111,7 +111,7 @@ std::tuple<std::string_view, T, bool> getenv_or(
*/
class defaults {
private:
BS_thread_pool _thread_pool{get_num_threads_from_env()};
ThreadPool _thread_pool{get_num_threads_from_env()};
CompatMode _compat_mode;
std::size_t _task_size;
std::size_t _gds_threshold;
@@ -212,7 +212,7 @@ class defaults {
*
* @return The default thread pool instance.
*/
[[nodiscard]] static BS_thread_pool& thread_pool();
[[nodiscard]] static ThreadPool& thread_pool();

/**
* @brief Get the number of threads in the default thread pool.
22 changes: 16 additions & 6 deletions cpp/include/kvikio/detail/parallel_operation.hpp
@@ -17,6 +17,7 @@
#include <kvikio/defaults.hpp>
#include <kvikio/detail/nvtx.hpp>
#include <kvikio/error.hpp>
#include <kvikio/threadpool_wrapper.hpp>
#include <kvikio/utils.hpp>

namespace kvikio {
@@ -75,6 +76,7 @@ std::future<std::size_t> submit_task(F op,
std::size_t size,
std::size_t file_offset,
std::size_t devPtr_offset,
ThreadPool* thread_pool = &defaults::thread_pool(),
std::uint64_t nvtx_payload = 0ull,
nvtx_color_type nvtx_color = NvtxManager::default_color())
{
@@ -85,7 +87,7 @@
decltype(file_offset),
decltype(devPtr_offset)>);

return defaults::thread_pool().submit_task([=] {
return thread_pool->submit_task([=] {
KVIKIO_NVTX_SCOPED_RANGE("task", nvtx_payload, nvtx_color);
return op(buf, size, file_offset, devPtr_offset);
});
@@ -101,12 +103,13 @@
template <typename F>
std::future<std::size_t> submit_move_only_task(
F op_move_only,
ThreadPool* thread_pool = &defaults::thread_pool(),
std::uint64_t nvtx_payload = 0ull,
nvtx_color_type nvtx_color = NvtxManager::default_color())
{
static_assert(std::is_invocable_r_v<std::size_t, F>);
auto op_copyable = make_copyable_lambda(std::move(op_move_only));
return defaults::thread_pool().submit_task([=] {
return thread_pool->submit_task([=] {
KVIKIO_NVTX_SCOPED_RANGE("task", nvtx_payload, nvtx_color);
return op_copyable();
});
@@ -124,6 +127,10 @@ std::future<std::size_t> submit_move_only_task(
* @param size Number of bytes to read or write.
* @param file_offset Byte offset to the start of the file.
* @param task_size Size of each task in bytes.
* @param devPtr_offset Offset relative to the `devPtr_base` pointer. This parameter should be used
* only with registered buffers.
* @param thread_pool Thread pool to use for parallel execution. Defaults to the global default
* thread pool.
* @return A future to be used later to check if the operation has finished its execution.
*/
template <typename F, typename T>
@@ -133,10 +140,12 @@ std::future<std::size_t> parallel_io(F op,
std::size_t file_offset,
std::size_t task_size,
std::size_t devPtr_offset,
ThreadPool* thread_pool = &defaults::thread_pool(),
std::uint64_t call_idx = 0,
nvtx_color_type nvtx_color = NvtxManager::default_color())
{
KVIKIO_EXPECT(task_size > 0, "`task_size` must be positive", std::invalid_argument);
KVIKIO_EXPECT(thread_pool != nullptr, "The thread pool must not be nullptr");
static_assert(std::is_invocable_r_v<std::size_t,
decltype(op),
decltype(buf),
@@ -146,16 +155,17 @@

// Single-task guard
if (task_size >= size || get_page_size() >= size) {
return detail::submit_task(op, buf, size, file_offset, devPtr_offset, call_idx, nvtx_color);
return detail::submit_task(
op, buf, size, file_offset, devPtr_offset, thread_pool, call_idx, nvtx_color);
}

std::vector<std::future<std::size_t>> tasks;
tasks.reserve(size / task_size);

// 1) Submit all tasks but the last one. These are all `task_size` sized tasks.
while (size > task_size) {
tasks.push_back(
detail::submit_task(op, buf, task_size, file_offset, devPtr_offset, call_idx, nvtx_color));
tasks.push_back(detail::submit_task(
op, buf, task_size, file_offset, devPtr_offset, thread_pool, call_idx, nvtx_color));
file_offset += task_size;
devPtr_offset += task_size;
size -= task_size;
@@ -170,7 +180,7 @@ std::future<std::size_t> parallel_io(F op,
}
return ret;
};
return detail::submit_move_only_task(std::move(last_task), call_idx, nvtx_color);
return detail::submit_move_only_task(std::move(last_task), thread_pool, call_idx, nvtx_color);
}

} // namespace kvikio
23 changes: 17 additions & 6 deletions cpp/include/kvikio/file_handle.hpp
@@ -20,6 +20,7 @@
#include <kvikio/shim/cufile.hpp>
#include <kvikio/shim/cufile_h_wrapper.hpp>
#include <kvikio/stream.hpp>
#include <kvikio/threadpool_wrapper.hpp>
#include <kvikio/utils.hpp>

namespace kvikio {
@@ -228,17 +229,22 @@ class FileHandle {
* in the null stream. When in KvikIO's compatibility mode or when accessing host memory, the
* operation is always default stream ordered like the rest of the non-async CUDA API. In this
* case, the value of `sync_default_stream` is ignored.
* @param thread_pool Thread pool to use for parallel execution. Defaults to the global default
* thread pool. The caller is responsible for ensuring that the thread pool remains valid until
* the returned future is consumed (i.e., until `get()` or `wait()` is called on it).
* @return Future that on completion returns the size of bytes that were successfully read.
*
* @note The `std::future` object's `wait()` or `get()` should not be called after the lifetime of
* the FileHandle object ends. Otherwise, the behavior is undefined.
* @note The returned `std::future` object must not outlive either the FileHandle or the thread
* pool. Calling `wait()` or `get()` on the future after the FileHandle or thread pool has been
* destroyed results in undefined behavior.
*/
std::future<std::size_t> pread(void* buf,
std::size_t size,
std::size_t file_offset = 0,
std::size_t task_size = defaults::task_size(),
std::size_t gds_threshold = defaults::gds_threshold(),
bool sync_default_stream = true);
bool sync_default_stream = true,
ThreadPool* thread_pool = &defaults::thread_pool());

/**
* @brief Writes specified bytes from device or host memory into the file in parallel.
@@ -265,17 +271,22 @@
* in the null stream. When in KvikIO's compatibility mode or when accessing host memory, the
* operation is always default stream ordered like the rest of the non-async CUDA API. In this
* case, the value of `sync_default_stream` is ignored.
* @param thread_pool Thread pool to use for parallel execution. Defaults to the global default
* thread pool. The caller is responsible for ensuring that the thread pool remains valid until
* the returned future is consumed (i.e., until `get()` or `wait()` is called on it).
* @return Future that on completion returns the size of bytes that were successfully written.
*
* @note The `std::future` object's `wait()` or `get()` should not be called after the lifetime of
* the FileHandle object ends. Otherwise, the behavior is undefined.
* @note The returned `std::future` object must not outlive either the FileHandle or the thread
* pool. Calling `wait()` or `get()` on the future after the FileHandle or thread pool has been
* destroyed results in undefined behavior.
*/
std::future<std::size_t> pwrite(void const* buf,
std::size_t size,
std::size_t file_offset = 0,
std::size_t task_size = defaults::task_size(),
std::size_t gds_threshold = defaults::gds_threshold(),
bool sync_default_stream = true);
bool sync_default_stream = true,
ThreadPool* thread_pool = &defaults::thread_pool());

/**
* @brief Reads specified bytes from the file into the device memory asynchronously.
12 changes: 9 additions & 3 deletions cpp/include/kvikio/mmap.hpp
@@ -9,6 +9,7 @@

#include <kvikio/defaults.hpp>
#include <kvikio/file_handle.hpp>
#include <kvikio/threadpool_wrapper.hpp>
#include <optional>

namespace kvikio {
@@ -162,19 +163,24 @@ class MmapHandle {
* specified, read starts from `offset` to the end of file
* @param offset File offset
* @param task_size Size of each task in bytes
* @param thread_pool Thread pool to use for parallel execution. Defaults to the global default
* thread pool. The caller is responsible for ensuring that the thread pool remains valid until
* the returned future is consumed (i.e., until `get()` or `wait()` is called on it).
* @return Future that on completion returns the size of bytes that were successfully read.
*
* @exception std::out_of_range if the read region specified by `offset` and `size` is
* outside the initial region specified when the mapping handle was constructed
* @exception std::runtime_error if the mapping handle is closed
*
* @note The `std::future` object's `wait()` or `get()` should not be called after the lifetime of
* the MmapHandle object ends. Otherwise, the behavior is undefined.
* @note The returned `std::future` object must not outlive either the MmapHandle or the thread
* pool. Calling `wait()` or `get()` on the future after the MmapHandle or thread pool has been
* destroyed results in undefined behavior.
*/
std::future<std::size_t> pread(void* buf,
std::optional<std::size_t> size = std::nullopt,
std::size_t offset = 0,
std::size_t task_size = defaults::task_size());
std::size_t task_size = defaults::task_size(),
ThreadPool* thread_pool = &defaults::thread_pool());
Member
Consider using std::shared_ptr<ThreadPool> throughout to encourage easier and safer lifetime management for users.

Contributor Author
Thanks. Thinking over passing the thread pool as a shared pointer: when the asynchronous function pread/pwrite returns, its copy of the shared pointer is destroyed. So to properly extend the pool's lifetime over the async operation and prevent use-after-free, we would need to further share ownership with the I/O tasks themselves, either with each task or with the last aggregate task. The pro is that there is no concern over the thread pool's lifetime at the point the std::future's result is waited for. The con is a slight increase in implementation complexity and runtime overhead.

If we pass a raw pointer instead, we claim no ownership responsibility and require users to keep the pool alive throughout the I/O operations. The pro is simplicity; the con is losing the safety benefits of smart pointers.

Contributor Author
There appears to be a very tricky problem.

To extend the lifetime of the thread pool properly during the asynchronous operations, we need to capture std::shared_ptr<ThreadPool> in the last task:
https://github.com/rapidsai/kvikio/blob/main/cpp/include/kvikio/detail/parallel_operation.hpp#L166

auto last_task = [=, thread_pool = thread_pool, tasks = std::move(tasks)]() mutable -> std::size_t {

Suppose the reference count is exactly 1 when the last task is being executed. When it is done, the task object goes out of scope precisely at https://github.com/bshoshany/thread-pool/blob/v4.1.0/include/BS_thread_pool.hpp#L938, the reference count reaches 0, and the pool starts being destroyed. In the destructor, we wait (sleep) (https://github.com/bshoshany/thread-pool/blob/v4.1.0/include/BS_thread_pool.hpp#L336) for the condition tasks_running == 0, which will never hold because --tasks_running takes place at the beginning of the worker's loop (https://github.com/bshoshany/thread-pool/blob/v4.1.0/include/BS_thread_pool.hpp#L915). So tasks_running will always be 1 and we wait forever in the destructor. Strangely, I haven't seen this in my unit tests, but I fear the deadlock may appear in production.
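
A minimal sketch of that scenario (illustrative only, not kvikio code; it assumes the BS::thread_pool v4.x API that this PR builds on):

#include <BS_thread_pool.hpp>
#include <chrono>
#include <cstddef>
#include <memory>
#include <thread>

int main()
{
  auto pool = std::make_shared<BS::thread_pool>(1);
  // The task co-owns the pool, mimicking the last aggregate I/O task.
  auto fut = pool->submit_task([pool]() -> std::size_t {
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
    return 42;
  });
  pool.reset();  // Drop the caller's reference; the running task now holds the last one.
  // The result is still delivered, but once the task's captured shared_ptr is released
  // inside the worker loop, ~thread_pool runs on that worker thread and may wait forever.
  return static_cast<int>(fut.get());
}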

Contributor Author
So I think if we do want to extend the lifetime of the thread pool, we need to add it directly to the returned future's result (the "shared state" in C++ terminology), i.e., instead of std::future<std::size_t> we would probably need std::future<std::pair<std::size_t, std::shared_ptr<ThreadPool>>>.
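
A hedged sketch of that idea (the helper name is made up; ThreadPool and submit_task are as in this PR):

#include <cstddef>
#include <future>
#include <memory>
#include <utility>

#include <kvikio/threadpool_wrapper.hpp>

template <typename F>
std::future<std::pair<std::size_t, std::shared_ptr<kvikio::ThreadPool>>> submit_with_keep_alive(
  std::shared_ptr<kvikio::ThreadPool> pool, F op)
{
  return pool->submit_task(
    [pool, op]() mutable -> std::pair<std::size_t, std::shared_ptr<kvikio::ThreadPool>> {
      // The pool travels with the result, so its last reference is normally
      // dropped on the caller's thread when the result is destroyed.
      return {op(), pool};
    });
}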

Contributor Author
At this point, I'm inclined to go back to the raw pointer approach, and ask users to shoulder the responsibility of lifetime management for the thread pool. 🤔
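
For reference, a hedged sketch of what the raw-pointer approach looks like from the caller's side (file name and buffer are illustrative):

#include <vector>

#include <kvikio/defaults.hpp>
#include <kvikio/file_handle.hpp>
#include <kvikio/threadpool_wrapper.hpp>

void read_with_private_pool()
{
  kvikio::ThreadPool pool(8);                  // caller-owned pool
  kvikio::FileHandle f("/tmp/data.bin", "r");  // illustrative path
  std::vector<char> buf(1 << 20);

  auto fut = f.pread(buf.data(),
                     buf.size(),
                     /* file_offset = */ 0,
                     kvikio::defaults::task_size(),
                     kvikio::defaults::gds_threshold(),
                     /* sync_default_stream = */ true,
                     &pool);
  fut.get();  // Consume the future before `pool` and `f` go out of scope.
}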

Member
sounds good

};

} // namespace kvikio
11 changes: 10 additions & 1 deletion cpp/include/kvikio/remote_handle.hpp
@@ -13,6 +13,7 @@

#include <kvikio/defaults.hpp>
#include <kvikio/error.hpp>
#include <kvikio/threadpool_wrapper.hpp>
#include <kvikio/utils.hpp>

struct curl_slist;
@@ -452,12 +453,20 @@ class RemoteHandle {
* @param size Number of bytes to read.
* @param file_offset File offset in bytes.
* @param task_size Size of each task in bytes.
* @param thread_pool Thread pool to use for parallel execution. Defaults to the global default
* thread pool. The caller is responsible for ensuring that the thread pool remains valid until
* the returned future is consumed (i.e., until `get()` or `wait()` is called on it).
* @return Future that on completion returns the size of bytes read, which is always `size`.
*
* @note The returned `std::future` object must not outlive either the RemoteHandle or the thread
* pool. Calling `wait()` or `get()` on the future after the RemoteHandle or thread pool has been
* destroyed results in undefined behavior.
*/
std::future<std::size_t> pread(void* buf,
std::size_t size,
std::size_t file_offset = 0,
std::size_t task_size = defaults::task_size());
std::size_t task_size = defaults::task_size(),
ThreadPool* thread_pool = &defaults::thread_pool());
};

} // namespace kvikio
23 changes: 4 additions & 19 deletions cpp/include/kvikio/threadpool_wrapper.hpp
@@ -9,24 +9,9 @@

namespace kvikio {

template <typename pool_type>
class thread_pool_wrapper : public pool_type {
public:
/**
* @brief Construct a new thread pool wrapper.
*
* @param nthreads The number of threads to use.
*/
thread_pool_wrapper(unsigned int nthreads) : pool_type{nthreads} {}

/**
* @brief Reset the number of threads in the thread pool.
*
* @param nthreads The number of threads to use.
*/
void reset(unsigned int nthreads) { pool_type::reset(nthreads); }
};

using BS_thread_pool = thread_pool_wrapper<BS::thread_pool>;
/**
* @brief Thread pool type used for parallel I/O operations.
*/
using ThreadPool = BS::thread_pool;

} // namespace kvikio
2 changes: 1 addition & 1 deletion cpp/src/defaults.cpp
@@ -172,7 +172,7 @@ bool defaults::is_compat_mode_preferred(CompatMode compat_mode) noexcept

bool defaults::is_compat_mode_preferred() { return is_compat_mode_preferred(compat_mode()); }

BS_thread_pool& defaults::thread_pool() { return instance()->_thread_pool; }
ThreadPool& defaults::thread_pool() { return instance()->_thread_pool; }

unsigned int defaults::thread_pool_nthreads() { return thread_pool().get_thread_count(); }

35 changes: 27 additions & 8 deletions cpp/src/file_handle.cpp
@@ -19,6 +19,7 @@
#include <kvikio/error.hpp>
#include <kvikio/file_handle.hpp>
#include <kvikio/file_utils.hpp>
#include <kvikio/threadpool_wrapper.hpp>

namespace kvikio {

@@ -148,8 +149,10 @@ std::future<std::size_t> FileHandle::pread(void* buf,
std::size_t file_offset,
std::size_t task_size,
std::size_t gds_threshold,
bool sync_default_stream)
bool sync_default_stream,
ThreadPool* thread_pool)
{
KVIKIO_EXPECT(thread_pool != nullptr, "The thread pool must not be nullptr");
auto& [nvtx_color, call_idx] = detail::get_next_color_and_call_idx();
KVIKIO_NVTX_FUNC_RANGE(size, nvtx_color);
if (is_host_memory(buf)) {
@@ -162,7 +165,7 @@
_file_direct_off.fd(), buf, size, file_offset, _file_direct_on.fd());
};

return parallel_io(op, buf, size, file_offset, task_size, 0, call_idx, nvtx_color);
return parallel_io(op, buf, size, file_offset, task_size, 0, thread_pool, call_idx, nvtx_color);
}

CUcontext ctx = get_context_from_pointer(buf);
@@ -192,17 +195,26 @@
return read(devPtr_base, size, file_offset, devPtr_offset, /* sync_default_stream = */ false);
};
auto [devPtr_base, base_size, devPtr_offset] = get_alloc_info(buf, &ctx);
return parallel_io(
task, devPtr_base, size, file_offset, task_size, devPtr_offset, call_idx, nvtx_color);
return parallel_io(task,
devPtr_base,
size,
file_offset,
task_size,
devPtr_offset,
thread_pool,
call_idx,
nvtx_color);
}

std::future<std::size_t> FileHandle::pwrite(void const* buf,
std::size_t size,
std::size_t file_offset,
std::size_t task_size,
std::size_t gds_threshold,
bool sync_default_stream)
bool sync_default_stream,
ThreadPool* thread_pool)
{
KVIKIO_EXPECT(thread_pool != nullptr, "The thread pool must not be nullptr");
auto& [nvtx_color, call_idx] = detail::get_next_color_and_call_idx();
KVIKIO_NVTX_FUNC_RANGE(size, nvtx_color);
if (is_host_memory(buf)) {
@@ -215,7 +227,7 @@ std::future<std::size_t> FileHandle::pwrite(void const* buf,
_file_direct_off.fd(), buf, size, file_offset, _file_direct_on.fd());
};

return parallel_io(op, buf, size, file_offset, task_size, 0, call_idx, nvtx_color);
return parallel_io(op, buf, size, file_offset, task_size, 0, thread_pool, call_idx, nvtx_color);
}

CUcontext ctx = get_context_from_pointer(buf);
@@ -245,8 +257,15 @@
return write(devPtr_base, size, file_offset, devPtr_offset, /* sync_default_stream = */ false);
};
auto [devPtr_base, base_size, devPtr_offset] = get_alloc_info(buf, &ctx);
return parallel_io(
op, devPtr_base, size, file_offset, task_size, devPtr_offset, call_idx, nvtx_color);
return parallel_io(op,
devPtr_base,
size,
file_offset,
task_size,
devPtr_offset,
thread_pool,
call_idx,
nvtx_color);
}

void FileHandle::read_async(void* devPtr_base,
5 changes: 4 additions & 1 deletion cpp/src/mmap.cpp
@@ -415,10 +415,12 @@ std::size_t MmapHandle::read(void* buf, std::optional<std::size_t> size, std::si
std::future<std::size_t> MmapHandle::pread(void* buf,
std::optional<std::size_t> size,
std::size_t offset,
std::size_t task_size)
std::size_t task_size,
ThreadPool* thread_pool)
{
KVIKIO_EXPECT(task_size <= defaults::bounce_buffer_size(),
"bounce buffer size cannot be less than task size.");
KVIKIO_EXPECT(thread_pool != nullptr, "The thread pool must not be nullptr");
auto actual_size = validate_and_adjust_read_args(size, offset);
if (actual_size == 0) { return make_ready_future(actual_size); }

@@ -448,6 +450,7 @@ std::future<std::size_t> MmapHandle::pread(void* buf,
offset,
task_size,
0, // dst buffer offset initial value
thread_pool,
call_idx,
nvtx_color);
}