Skip to content

Commit

Permalink
Add CallbackThread to QueueUniformCudaHipRt
Browse files Browse the repository at this point in the history
  • Loading branch information
tonydp03 authored and j-stephan committed May 13, 2022
1 parent ee309fc commit dfc286a
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 7 deletions.
78 changes: 78 additions & 0 deletions include/alpaka/core/CallbackThread.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/* Copyright 2022 Antonio Di Pilato
*
* This file is part of alpaka.
*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
*/

#pragma once

#include <atomic>
#include <condition_variable>
#include <functional>
#include <future>
#include <mutex>
#include <queue>
#include <thread>

namespace alpaka::core
{
class CallbackThread
{
using Task = std::packaged_task<void()>;

public:
~CallbackThread()
{
m_stop = true;
m_cond.notify_one();
if(m_thread.joinable())
m_thread.join();
}
auto submit(Task&& newTask) -> std::future<void>
{
auto f = newTask.get_future();
{
std::unique_lock<std::mutex> lock{m_mutex};
m_tasks.emplace(std::move(newTask));
if(!m_thread.joinable())
startWorkerThread();
}
m_cond.notify_one();
return f;
}

private:
std::thread m_thread;
std::condition_variable m_cond;
std::mutex m_mutex;
std::atomic<bool> m_stop{false};
std::queue<Task> m_tasks;

auto startWorkerThread() -> void
{
m_thread = std::thread(
[this]
{
Task task;
while(true)
{
{
std::unique_lock<std::mutex> lock{m_mutex};
m_cond.wait(lock, [this] { return m_stop || !m_tasks.empty(); });

if(m_stop && m_tasks.empty())
break;

task = std::move(m_tasks.front());
m_tasks.pop();
}

task();
}
});
}
};
} // namespace alpaka::core
19 changes: 12 additions & 7 deletions include/alpaka/queue/cuda_hip/QueueUniformCudaHipRt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,11 @@
# include <alpaka/core/Hip.hpp>
# endif

# include <alpaka/core/CallbackThread.hpp>

# include <condition_variable>
# include <functional>
# include <future>
# include <memory>
# include <mutex>
# include <thread>
Expand Down Expand Up @@ -85,6 +88,8 @@ namespace alpaka

public:
DevUniformCudaHipRt<TApi> const m_dev; //!< The device this queue is bound to.
core::CallbackThread m_callbackThread;

private:
typename TApi::Stream_t m_UniformCudaHipQueue;
};
Expand Down Expand Up @@ -114,6 +119,10 @@ namespace alpaka
{
return m_spQueueImpl->getNativeHandle();
}
auto getCallbackThread() -> core::CallbackThread&
{
return m_spQueueImpl->m_callbackThread;
}

public:
std::shared_ptr<QueueUniformCudaHipRtImpl<TApi>> m_spQueueImpl;
Expand Down Expand Up @@ -241,7 +250,7 @@ namespace alpaka
// callback thread. The CUDA/HIP thread signals the std::thread when it is ready to execute the task.
// The CUDA/HIP thread is waiting for the std::thread to signal that it is finished executing the task
// before it executes the next task in the queue (CUDA/HIP stream).
std::thread t(
auto f = queue.getCallbackThread().submit(std::packaged_task<void()>(
[spCallbackSynchronizationData, task]()
{
// If the callback has not yet been called, we wait for it.
Expand All @@ -261,15 +270,11 @@ namespace alpaka
spCallbackSynchronizationData->m_state = CallbackState::finished;
}
spCallbackSynchronizationData->m_event.notify_one();
});
}));

if constexpr(TBlocking)
{
t.join();
}
else
{
t.detach();
f.wait();
}
}
};
Expand Down

0 comments on commit dfc286a

Please sign in to comment.