diff --git a/cpp/include/ucxx/worker_progress_thread.h b/cpp/include/ucxx/worker_progress_thread.h index 9c9ac44d..38aae033 100644 --- a/cpp/include/ucxx/worker_progress_thread.h +++ b/cpp/include/ucxx/worker_progress_thread.h @@ -13,6 +13,7 @@ namespace ucxx { +typedef std::function SignalWorkerFunction; typedef std::function ProgressThreadStartCallback; typedef void* ProgressThreadStartCallbackArg; @@ -21,6 +22,9 @@ class WorkerProgressThread { std::thread _thread{}; ///< Thread object bool _stop{false}; ///< Signal to stop on next iteration bool _pollingMode{false}; ///< Whether thread will use polling mode to progress + SignalWorkerFunction _signalWorkerFunction{ + nullptr}; ///< Function signaling worker to wake the progress event (when _pollingMode is + ///< `false`) ProgressThreadStartCallback _startCallback{ nullptr}; ///< Callback to execute at start of the progress thread ProgressThreadStartCallbackArg _startCallbackArg{ @@ -71,6 +75,8 @@ class WorkerProgressThread { * @param[in] pollingMode whether the thread should use polling mode to * progress. * @param[in] progressFunction user-defined progress function implementation. + * @param[in] signalWorkerFunction user-defined function to wake the worker + * progress event (when `pollingMode` is `false`). * @param[in] startCallback user-defined callback function to be executed * at the start of the progress thread. * @param[in] startCallbackArg an argument to be passed to the start callback. @@ -79,6 +85,7 @@ class WorkerProgressThread { */ WorkerProgressThread(const bool pollingMode, std::function progressFunction, + std::function signalWorkerFunction, ProgressThreadStartCallback startCallback, ProgressThreadStartCallbackArg startCallbackArg, std::shared_ptr delayedSubmissionCollection); diff --git a/cpp/src/worker.cpp b/cpp/src/worker.cpp index 83aa8756..403b17c1 100644 --- a/cpp/src/worker.cpp +++ b/cpp/src/worker.cpp @@ -254,19 +254,18 @@ void Worker::startProgressThread(const bool pollingMode) if (!pollingMode) initBlockingProgressMode(); auto progressFunction = pollingMode ? std::bind(&Worker::progress, this) : std::bind(&Worker::progressWorkerEvent, this); + auto signalWorkerFunction = + pollingMode ? std::function{[]() {}} : std::bind(&Worker::signal, this); _progressThread = std::make_shared(pollingMode, progressFunction, + signalWorkerFunction, _progressThreadStartCallback, _progressThreadStartCallbackArg, _delayedSubmissionCollection); } -void Worker::stopProgressThreadNoWarn() -{ - if (_progressThread && !_progressThread->pollingMode()) signal(); - _progressThread = nullptr; -} +void Worker::stopProgressThreadNoWarn() { _progressThread = nullptr; } void Worker::stopProgressThread() { diff --git a/cpp/src/worker_progress_thread.cpp b/cpp/src/worker_progress_thread.cpp index f9156b14..e772ec9d 100644 --- a/cpp/src/worker_progress_thread.cpp +++ b/cpp/src/worker_progress_thread.cpp @@ -28,10 +28,14 @@ void WorkerProgressThread::progressUntilSync( WorkerProgressThread::WorkerProgressThread( const bool pollingMode, std::function progressFunction, + std::function signalWorkerFunction, ProgressThreadStartCallback startCallback, ProgressThreadStartCallbackArg startCallbackArg, std::shared_ptr delayedSubmissionCollection) - : _pollingMode(pollingMode), _startCallback(startCallback), _startCallbackArg(startCallbackArg) + : _pollingMode(pollingMode), + _signalWorkerFunction(signalWorkerFunction), + _startCallback(startCallback), + _startCallbackArg(startCallbackArg) { _thread = std::thread(WorkerProgressThread::progressUntilSync, progressFunction, @@ -49,6 +53,7 @@ WorkerProgressThread::~WorkerProgressThread() } _stop = true; + _signalWorkerFunction(); _thread.join(); }