From b394a723878f27f3baace42a5fa12795e8303081 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 4 Jul 2024 10:00:03 -0700 Subject: [PATCH] Remove usage of `std::unique_ptr` for `ucxx::TrackedRequests` members --- cpp/include/ucxx/inflight_requests.h | 21 ++------- cpp/src/inflight_requests.cpp | 70 +++++++++++++--------------- cpp/src/worker.cpp | 2 +- 3 files changed, 38 insertions(+), 55 deletions(-) diff --git a/cpp/include/ucxx/inflight_requests.h b/cpp/include/ucxx/inflight_requests.h index a10b8f3e..68a8216d 100644 --- a/cpp/include/ucxx/inflight_requests.h +++ b/cpp/include/ucxx/inflight_requests.h @@ -21,13 +21,6 @@ class Request; */ typedef std::map> InflightRequestsMap; -/** - * @brief Pre-defined type for a pointer to an inflight request map. - * - * A pre-defined type for a pointer to an inflight request map, used as a convenience type. - */ -typedef std::unique_ptr InflightRequestsMapPtr; - /** * @brief A container for the different types of tracked requests. * @@ -35,15 +28,11 @@ typedef std::unique_ptr InflightRequestsMapPtr; * those still valid (inflight), and those scheduled for cancelation (canceling). */ typedef struct TrackedRequests { - InflightRequestsMapPtr _inflight{ - std::make_unique()}; ///< Valid requests awaiting completion. - InflightRequestsMapPtr _canceling{ - std::make_unique()}; ///< Requests scheduled for cancelation. - std::unique_ptr _mutex{ - std::make_unique()}; ///< Mutex to control access to inflight requests container - std::unique_ptr _cancelMutex{ - std::make_unique()}; ///< Mutex to allow cancelation and prevent removing requests - ///< simultaneously + InflightRequestsMap _inflight{}; ///< Valid requests awaiting completion. + InflightRequestsMap _canceling{}; ///< Requests scheduled for cancelation. + std::mutex _mutex{}; ///< Mutex to control access to inflight requests container + std::mutex + _cancelMutex{}; ///< Mutex to allow cancelation and prevent removing requests simultaneously } TrackedRequests; /** diff --git a/cpp/src/inflight_requests.cpp b/cpp/src/inflight_requests.cpp index 4aafec9c..eca05c40 100644 --- a/cpp/src/inflight_requests.cpp +++ b/cpp/src/inflight_requests.cpp @@ -15,16 +15,16 @@ InflightRequests::~InflightRequests() { cancelAll(); } size_t InflightRequests::size() { std::scoped_lock localLock{_mutex}; - std::lock_guard lock(*_trackedRequests->_mutex); - return _trackedRequests->_inflight->size(); + std::lock_guard lock(_trackedRequests->_mutex); + return _trackedRequests->_inflight.size(); } void InflightRequests::insert(std::shared_ptr request) { std::scoped_lock localLock{_mutex}; - std::lock_guard lock(*_trackedRequests->_mutex); + std::lock_guard lock(_trackedRequests->_mutex); - _trackedRequests->_inflight->insert({request.get(), request}); + _trackedRequests->_inflight.insert({request.get(), request}); } void InflightRequests::merge(TrackedRequestsPtr trackedRequests) @@ -33,19 +33,13 @@ void InflightRequests::merge(TrackedRequestsPtr trackedRequests) if (trackedRequests == nullptr) return; std::scoped_lock localLock{_mutex}; - std::scoped_lock lock{*_trackedRequests->_cancelMutex, - *_trackedRequests->_mutex, - *trackedRequests->_cancelMutex, - *trackedRequests->_mutex}; - - if (trackedRequests->_inflight != nullptr) - _trackedRequests->_inflight->merge(*(trackedRequests->_inflight)); - else - ucxx_error("Invalid _inflight object during merge"); - if (trackedRequests->_canceling != nullptr) - _trackedRequests->_canceling->merge(*(trackedRequests->_canceling)); - else - ucxx_error("Invalid _canceling object during merge"); + std::scoped_lock lock{_trackedRequests->_cancelMutex, + _trackedRequests->_mutex, + trackedRequests->_cancelMutex, + trackedRequests->_mutex}; + + _trackedRequests->_inflight.merge(trackedRequests->_inflight); + _trackedRequests->_canceling.merge(trackedRequests->_canceling); } } @@ -53,7 +47,7 @@ void InflightRequests::remove(const Request* const request) { do { std::scoped_lock localLock{_mutex}; - int result = std::try_lock(*_trackedRequests->_cancelMutex, *_trackedRequests->_mutex); + int result = std::try_lock(_trackedRequests->_cancelMutex, _trackedRequests->_mutex); /** * `result` can be have one of three values: @@ -68,9 +62,9 @@ void InflightRequests::remove(const Request* const request) if (result == 0) { return; } else if (result == -1) { - auto search = _trackedRequests->_inflight->find(request); + auto search = _trackedRequests->_inflight.find(request); decltype(search->second) tmpRequest; - if (search != _trackedRequests->_inflight->end()) { + if (search != _trackedRequests->_inflight.end()) { /** * If this is the last request to hold `std::shared_ptr` erasing it * may cause the `ucxx::Endpoint`s destructor and subsequently the `closeBlocking()` @@ -80,10 +74,10 @@ void InflightRequests::remove(const Request* const request) * destroy the object upon this method's return. */ tmpRequest = search->second; - _trackedRequests->_inflight->erase(search); + _trackedRequests->_inflight.erase(search); } - _trackedRequests->_cancelMutex->unlock(); - _trackedRequests->_mutex->unlock(); + _trackedRequests->_cancelMutex.unlock(); + _trackedRequests->_mutex.unlock(); return; } } while (true); @@ -95,12 +89,12 @@ size_t InflightRequests::dropCanceled() { std::scoped_lock localLock{_mutex}; - std::scoped_lock lock{*_trackedRequests->_cancelMutex}; - for (auto it = _trackedRequests->_canceling->begin(); - it != _trackedRequests->_canceling->end();) { + std::scoped_lock lock{_trackedRequests->_cancelMutex}; + for (auto it = _trackedRequests->_canceling.begin(); + it != _trackedRequests->_canceling.end();) { auto request = it->second; if (request != nullptr && request->getStatus() != UCS_INPROGRESS) { - it = _trackedRequests->_canceling->erase(it); + it = _trackedRequests->_canceling.erase(it); ++removed; } else { ++it; @@ -117,8 +111,8 @@ size_t InflightRequests::getCancelingSize() size_t cancelingSize = 0; { std::scoped_lock localLock{_mutex}; - std::scoped_lock lock{*_trackedRequests->_cancelMutex}; - cancelingSize = _trackedRequests->_canceling->size(); + std::scoped_lock lock{_trackedRequests->_cancelMutex}; + cancelingSize = _trackedRequests->_canceling.size(); } return cancelingSize; @@ -130,30 +124,30 @@ size_t InflightRequests::cancelAll() size_t total; { std::scoped_lock localLock{_mutex}; - std::scoped_lock lock{*_trackedRequests->_cancelMutex, *_trackedRequests->_mutex}; - total = _trackedRequests->_inflight->size(); + std::scoped_lock lock{_trackedRequests->_cancelMutex, _trackedRequests->_mutex}; + total = _trackedRequests->_inflight.size(); // Fast path when no requests have been registered or the map has been // previously released. if (total == 0) return 0; - toCancel = std::exchange(_trackedRequests->_inflight, std::make_unique()); + toCancel = std::exchange(_trackedRequests->_inflight, InflightRequestsMap()); ucxx_debug("ucxx::InflightRequests::%s, canceling %lu requests", __func__, total); - for (auto& r : *toCancel) { + for (auto& r : toCancel) { auto request = r.second; if (request != nullptr) { request->cancel(); } } - _trackedRequests->_canceling->merge(*toCancel); + _trackedRequests->_canceling.merge(toCancel); // dropCanceled(); - for (auto it = _trackedRequests->_canceling->begin(); - it != _trackedRequests->_canceling->end();) { + for (auto it = _trackedRequests->_canceling.begin(); + it != _trackedRequests->_canceling.end();) { auto request = it->second; if (request != nullptr && request->getStatus() != UCS_INPROGRESS) { - it = _trackedRequests->_canceling->erase(it); + it = _trackedRequests->_canceling.erase(it); } else { ++it; } @@ -166,7 +160,7 @@ size_t InflightRequests::cancelAll() TrackedRequestsPtr InflightRequests::release() { std::scoped_lock localLock{_mutex}; - std::scoped_lock lock{*_trackedRequests->_cancelMutex, *_trackedRequests->_mutex}; + std::scoped_lock lock{_trackedRequests->_cancelMutex, _trackedRequests->_mutex}; return std::exchange(_trackedRequests, std::make_unique()); } diff --git a/cpp/src/worker.cpp b/cpp/src/worker.cpp index 480ee0e0..d6961d94 100644 --- a/cpp/src/worker.cpp +++ b/cpp/src/worker.cpp @@ -511,7 +511,7 @@ void Worker::scheduleRequestCancel(TrackedRequestsPtr trackedRequests) __func__, this, _handle, - trackedRequests->_inflight->size() + trackedRequests->_canceling->size()); + trackedRequests->_inflight.size() + trackedRequests->_canceling.size()); _inflightRequestsToCancel->merge(std::move(trackedRequests)); } }