From 92d276ed694a3b9d4b555bde5c98e9017d83a5d3 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Tue, 11 Apr 2023 05:33:43 -0700 Subject: [PATCH] Signal the worker progress event from the worker progress thread Without signaling the worker progress event from the worker progress thread while shutting down the thread it may occur that the thread hangs as the event never gets awaken due to the lack of tasks to progress. This change ensures the worker progress event is signaled before joining the thread. --- cpp/include/ucxx/worker_progress_thread.h | 7 +++++++ cpp/src/worker.cpp | 9 ++++----- cpp/src/worker_progress_thread.cpp | 7 ++++++- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/cpp/include/ucxx/worker_progress_thread.h b/cpp/include/ucxx/worker_progress_thread.h index 9c9ac44d..38aae033 100644 --- a/cpp/include/ucxx/worker_progress_thread.h +++ b/cpp/include/ucxx/worker_progress_thread.h @@ -13,6 +13,7 @@ namespace ucxx { +typedef std::function SignalWorkerFunction; typedef std::function ProgressThreadStartCallback; typedef void* ProgressThreadStartCallbackArg; @@ -21,6 +22,9 @@ class WorkerProgressThread { std::thread _thread{}; ///< Thread object bool _stop{false}; ///< Signal to stop on next iteration bool _pollingMode{false}; ///< Whether thread will use polling mode to progress + SignalWorkerFunction _signalWorkerFunction{ + nullptr}; ///< Function signaling worker to wake the progress event (when _pollingMode is + ///< `false`) ProgressThreadStartCallback _startCallback{ nullptr}; ///< Callback to execute at start of the progress thread ProgressThreadStartCallbackArg _startCallbackArg{ @@ -71,6 +75,8 @@ class WorkerProgressThread { * @param[in] pollingMode whether the thread should use polling mode to * progress. * @param[in] progressFunction user-defined progress function implementation. + * @param[in] signalWorkerFunction user-defined function to wake the worker + * progress event (when `pollingMode` is `false`). * @param[in] startCallback user-defined callback function to be executed * at the start of the progress thread. * @param[in] startCallbackArg an argument to be passed to the start callback. @@ -79,6 +85,7 @@ class WorkerProgressThread { */ WorkerProgressThread(const bool pollingMode, std::function progressFunction, + std::function signalWorkerFunction, ProgressThreadStartCallback startCallback, ProgressThreadStartCallbackArg startCallbackArg, std::shared_ptr delayedSubmissionCollection); diff --git a/cpp/src/worker.cpp b/cpp/src/worker.cpp index 83aa8756..403b17c1 100644 --- a/cpp/src/worker.cpp +++ b/cpp/src/worker.cpp @@ -254,19 +254,18 @@ void Worker::startProgressThread(const bool pollingMode) if (!pollingMode) initBlockingProgressMode(); auto progressFunction = pollingMode ? std::bind(&Worker::progress, this) : std::bind(&Worker::progressWorkerEvent, this); + auto signalWorkerFunction = + pollingMode ? std::function{[]() {}} : std::bind(&Worker::signal, this); _progressThread = std::make_shared(pollingMode, progressFunction, + signalWorkerFunction, _progressThreadStartCallback, _progressThreadStartCallbackArg, _delayedSubmissionCollection); } -void Worker::stopProgressThreadNoWarn() -{ - if (_progressThread && !_progressThread->pollingMode()) signal(); - _progressThread = nullptr; -} +void Worker::stopProgressThreadNoWarn() { _progressThread = nullptr; } void Worker::stopProgressThread() { diff --git a/cpp/src/worker_progress_thread.cpp b/cpp/src/worker_progress_thread.cpp index f9156b14..e772ec9d 100644 --- a/cpp/src/worker_progress_thread.cpp +++ b/cpp/src/worker_progress_thread.cpp @@ -28,10 +28,14 @@ void WorkerProgressThread::progressUntilSync( WorkerProgressThread::WorkerProgressThread( const bool pollingMode, std::function progressFunction, + std::function signalWorkerFunction, ProgressThreadStartCallback startCallback, ProgressThreadStartCallbackArg startCallbackArg, std::shared_ptr delayedSubmissionCollection) - : _pollingMode(pollingMode), _startCallback(startCallback), _startCallbackArg(startCallbackArg) + : _pollingMode(pollingMode), + _signalWorkerFunction(signalWorkerFunction), + _startCallback(startCallback), + _startCallbackArg(startCallbackArg) { _thread = std::thread(WorkerProgressThread::progressUntilSync, progressFunction, @@ -49,6 +53,7 @@ WorkerProgressThread::~WorkerProgressThread() } _stop = true; + _signalWorkerFunction(); _thread.join(); }