From dfc286ab82d95b284885c09eef5140e34d3c33eb Mon Sep 17 00:00:00 2001 From: tonydp03 Date: Wed, 11 May 2022 16:02:42 +0200 Subject: [PATCH] Add CallbackThread to QueueUniformCudaHipRt --- include/alpaka/core/CallbackThread.hpp | 78 +++++++++++++++++++ .../queue/cuda_hip/QueueUniformCudaHipRt.hpp | 19 +++-- 2 files changed, 90 insertions(+), 7 deletions(-) create mode 100644 include/alpaka/core/CallbackThread.hpp diff --git a/include/alpaka/core/CallbackThread.hpp b/include/alpaka/core/CallbackThread.hpp new file mode 100644 index 000000000000..a791cd174409 --- /dev/null +++ b/include/alpaka/core/CallbackThread.hpp @@ -0,0 +1,78 @@ +/* Copyright 2022 Antonio Di Pilato + * + * This file is part of alpaka. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace alpaka::core +{ + class CallbackThread + { + using Task = std::packaged_task; + + public: + ~CallbackThread() + { + m_stop = true; + m_cond.notify_one(); + if(m_thread.joinable()) + m_thread.join(); + } + auto submit(Task&& newTask) -> std::future + { + auto f = newTask.get_future(); + { + std::unique_lock lock{m_mutex}; + m_tasks.emplace(std::move(newTask)); + if(!m_thread.joinable()) + startWorkerThread(); + } + m_cond.notify_one(); + return f; + } + + private: + std::thread m_thread; + std::condition_variable m_cond; + std::mutex m_mutex; + std::atomic m_stop{false}; + std::queue m_tasks; + + auto startWorkerThread() -> void + { + m_thread = std::thread( + [this] + { + Task task; + while(true) + { + { + std::unique_lock lock{m_mutex}; + m_cond.wait(lock, [this] { return m_stop || !m_tasks.empty(); }); + + if(m_stop && m_tasks.empty()) + break; + + task = std::move(m_tasks.front()); + m_tasks.pop(); + } + + task(); + } + }); + } + }; +} // namespace alpaka::core diff --git a/include/alpaka/queue/cuda_hip/QueueUniformCudaHipRt.hpp b/include/alpaka/queue/cuda_hip/QueueUniformCudaHipRt.hpp index d1dcdf8c6e0a..ec492ef62de2 100644 --- a/include/alpaka/queue/cuda_hip/QueueUniformCudaHipRt.hpp +++ b/include/alpaka/queue/cuda_hip/QueueUniformCudaHipRt.hpp @@ -27,8 +27,11 @@ # include # endif +# include + # include # include +# include # include # include # include @@ -85,6 +88,8 @@ namespace alpaka public: DevUniformCudaHipRt const m_dev; //!< The device this queue is bound to. + core::CallbackThread m_callbackThread; + private: typename TApi::Stream_t m_UniformCudaHipQueue; }; @@ -114,6 +119,10 @@ namespace alpaka { return m_spQueueImpl->getNativeHandle(); } + auto getCallbackThread() -> core::CallbackThread& + { + return m_spQueueImpl->m_callbackThread; + } public: std::shared_ptr> m_spQueueImpl; @@ -241,7 +250,7 @@ namespace alpaka // callback thread. The CUDA/HIP thread signals the std::thread when it is ready to execute the task. // The CUDA/HIP thread is waiting for the std::thread to signal that it is finished executing the task // before it executes the next task in the queue (CUDA/HIP stream). - std::thread t( + auto f = queue.getCallbackThread().submit(std::packaged_task( [spCallbackSynchronizationData, task]() { // If the callback has not yet been called, we wait for it. @@ -261,15 +270,11 @@ namespace alpaka spCallbackSynchronizationData->m_state = CallbackState::finished; } spCallbackSynchronizationData->m_event.notify_one(); - }); + })); if constexpr(TBlocking) { - t.join(); - } - else - { - t.detach(); + f.wait(); } } };