From b6e4185d7eff0ca2e36baa32b58f8db8933104d9 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 4 Jul 2024 09:21:44 -0700 Subject: [PATCH] Add mutex for `ucxx::InflightRequests` Add mutex to protect access to `ucxx::InflightRequests` from race conditions, particularly after release/before merge of `ucxx::TrackedRequests`. This has room for improvement, perhaps moving most of the implementation into `ucxx::TrackedRequests`. --- cpp/include/ucxx/inflight_requests.h | 9 ++++--- cpp/src/inflight_requests.cpp | 37 ++++++++++++++++++---------- 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/cpp/include/ucxx/inflight_requests.h b/cpp/include/ucxx/inflight_requests.h index d71caf88..a10b8f3e 100644 --- a/cpp/include/ucxx/inflight_requests.h +++ b/cpp/include/ucxx/inflight_requests.h @@ -39,9 +39,11 @@ typedef struct TrackedRequests { std::make_unique()}; ///< Valid requests awaiting completion. InflightRequestsMapPtr _canceling{ std::make_unique()}; ///< Requests scheduled for cancelation. - std::mutex _mutex{}; ///< Mutex to control access to inflight requests container - std::mutex - _cancelMutex{}; ///< Mutex to allow cancelation and prevent removing requests simultaneously + std::unique_ptr _mutex{ + std::make_unique()}; ///< Mutex to control access to inflight requests container + std::unique_ptr _cancelMutex{ + std::make_unique()}; ///< Mutex to allow cancelation and prevent removing requests + ///< simultaneously } TrackedRequests; /** @@ -64,6 +66,7 @@ class InflightRequests { std::make_unique()}; ///< Container storing pointers to all inflight ///< and in cancelation process requests known to ///< the owner of this object + std::recursive_mutex _mutex{}; ///< Mutex to control access to class resources /** * @brief Drop references to requests that completed cancelation. diff --git a/cpp/src/inflight_requests.cpp b/cpp/src/inflight_requests.cpp index 4aec90a5..4aafec9c 100644 --- a/cpp/src/inflight_requests.cpp +++ b/cpp/src/inflight_requests.cpp @@ -14,13 +14,15 @@ InflightRequests::~InflightRequests() { cancelAll(); } size_t InflightRequests::size() { - std::lock_guard lock(_trackedRequests->_mutex); + std::scoped_lock localLock{_mutex}; + std::lock_guard lock(*_trackedRequests->_mutex); return _trackedRequests->_inflight->size(); } void InflightRequests::insert(std::shared_ptr request) { - std::lock_guard lock(_trackedRequests->_mutex); + std::scoped_lock localLock{_mutex}; + std::lock_guard lock(*_trackedRequests->_mutex); _trackedRequests->_inflight->insert({request.get(), request}); } @@ -28,10 +30,14 @@ void InflightRequests::insert(std::shared_ptr request) void InflightRequests::merge(TrackedRequestsPtr trackedRequests) { { - std::scoped_lock lock{_trackedRequests->_cancelMutex, - _trackedRequests->_mutex, - trackedRequests->_cancelMutex, - trackedRequests->_mutex}; + if (trackedRequests == nullptr) return; + + std::scoped_lock localLock{_mutex}; + std::scoped_lock lock{*_trackedRequests->_cancelMutex, + *_trackedRequests->_mutex, + *trackedRequests->_cancelMutex, + *trackedRequests->_mutex}; + if (trackedRequests->_inflight != nullptr) _trackedRequests->_inflight->merge(*(trackedRequests->_inflight)); else @@ -46,7 +52,8 @@ void InflightRequests::merge(TrackedRequestsPtr trackedRequests) void InflightRequests::remove(const Request* const request) { do { - int result = std::try_lock(_trackedRequests->_cancelMutex, _trackedRequests->_mutex); + std::scoped_lock localLock{_mutex}; + int result = std::try_lock(*_trackedRequests->_cancelMutex, *_trackedRequests->_mutex); /** * `result` can be have one of three values: @@ -75,8 +82,8 @@ void InflightRequests::remove(const Request* const request) tmpRequest = search->second; _trackedRequests->_inflight->erase(search); } - _trackedRequests->_cancelMutex.unlock(); - _trackedRequests->_mutex.unlock(); + _trackedRequests->_cancelMutex->unlock(); + _trackedRequests->_mutex->unlock(); return; } } while (true); @@ -87,7 +94,8 @@ size_t InflightRequests::dropCanceled() size_t removed = 0; { - std::scoped_lock lock{_trackedRequests->_cancelMutex}; + std::scoped_lock localLock{_mutex}; + std::scoped_lock lock{*_trackedRequests->_cancelMutex}; for (auto it = _trackedRequests->_canceling->begin(); it != _trackedRequests->_canceling->end();) { auto request = it->second; @@ -108,7 +116,8 @@ size_t InflightRequests::getCancelingSize() dropCanceled(); size_t cancelingSize = 0; { - std::scoped_lock lock{_trackedRequests->_cancelMutex}; + std::scoped_lock localLock{_mutex}; + std::scoped_lock lock{*_trackedRequests->_cancelMutex}; cancelingSize = _trackedRequests->_canceling->size(); } @@ -120,7 +129,8 @@ size_t InflightRequests::cancelAll() decltype(_trackedRequests->_inflight) toCancel; size_t total; { - std::scoped_lock lock{_trackedRequests->_cancelMutex, _trackedRequests->_mutex}; + std::scoped_lock localLock{_mutex}; + std::scoped_lock lock{*_trackedRequests->_cancelMutex, *_trackedRequests->_mutex}; total = _trackedRequests->_inflight->size(); // Fast path when no requests have been registered or the map has been @@ -155,7 +165,8 @@ size_t InflightRequests::cancelAll() TrackedRequestsPtr InflightRequests::release() { - std::scoped_lock lock{_trackedRequests->_cancelMutex, _trackedRequests->_mutex}; + std::scoped_lock localLock{_mutex}; + std::scoped_lock lock{*_trackedRequests->_cancelMutex, *_trackedRequests->_mutex}; return std::exchange(_trackedRequests, std::make_unique()); }