Skip to content

Commit

Permalink
Add mutex for ucxx::InflightRequests
Browse files Browse the repository at this point in the history
Add mutex to protect access to `ucxx::InflightRequests` from race
conditions, particularly after release/before merge of
`ucxx::TrackedRequests`. This has room for improvement, perhaps moving
most of the implementation into `ucxx::TrackedRequests`.
  • Loading branch information
pentschev committed Jul 4, 2024
1 parent 3836700 commit b6e4185
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 16 deletions.
9 changes: 6 additions & 3 deletions cpp/include/ucxx/inflight_requests.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@ typedef struct TrackedRequests {
std::make_unique<InflightRequestsMap>()}; ///< Valid requests awaiting completion.
InflightRequestsMapPtr _canceling{
std::make_unique<InflightRequestsMap>()}; ///< Requests scheduled for cancelation.
std::mutex _mutex{}; ///< Mutex to control access to inflight requests container
std::mutex
_cancelMutex{}; ///< Mutex to allow cancelation and prevent removing requests simultaneously
std::unique_ptr<std::mutex> _mutex{
std::make_unique<std::mutex>()}; ///< Mutex to control access to inflight requests container
std::unique_ptr<std::mutex> _cancelMutex{
std::make_unique<std::mutex>()}; ///< Mutex to allow cancelation and prevent removing requests
///< simultaneously
} TrackedRequests;

/**
Expand All @@ -64,6 +66,7 @@ class InflightRequests {
std::make_unique<TrackedRequests>()}; ///< Container storing pointers to all inflight
///< and in cancelation process requests known to
///< the owner of this object
std::recursive_mutex _mutex{}; ///< Mutex to control access to class resources

/**
* @brief Drop references to requests that completed cancelation.
Expand Down
37 changes: 24 additions & 13 deletions cpp/src/inflight_requests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,30 @@ InflightRequests::~InflightRequests() { cancelAll(); }

size_t InflightRequests::size()
{
std::lock_guard<std::mutex> lock(_trackedRequests->_mutex);
std::scoped_lock localLock{_mutex};
std::lock_guard<std::mutex> lock(*_trackedRequests->_mutex);
return _trackedRequests->_inflight->size();
}

void InflightRequests::insert(std::shared_ptr<Request> request)
{
std::lock_guard<std::mutex> lock(_trackedRequests->_mutex);
std::scoped_lock localLock{_mutex};
std::lock_guard<std::mutex> lock(*_trackedRequests->_mutex);

_trackedRequests->_inflight->insert({request.get(), request});
}

void InflightRequests::merge(TrackedRequestsPtr trackedRequests)
{
{
std::scoped_lock lock{_trackedRequests->_cancelMutex,
_trackedRequests->_mutex,
trackedRequests->_cancelMutex,
trackedRequests->_mutex};
if (trackedRequests == nullptr) return;

std::scoped_lock localLock{_mutex};
std::scoped_lock lock{*_trackedRequests->_cancelMutex,
*_trackedRequests->_mutex,
*trackedRequests->_cancelMutex,
*trackedRequests->_mutex};

if (trackedRequests->_inflight != nullptr)
_trackedRequests->_inflight->merge(*(trackedRequests->_inflight));
else
Expand All @@ -46,7 +52,8 @@ void InflightRequests::merge(TrackedRequestsPtr trackedRequests)
void InflightRequests::remove(const Request* const request)
{
do {
int result = std::try_lock(_trackedRequests->_cancelMutex, _trackedRequests->_mutex);
std::scoped_lock localLock{_mutex};
int result = std::try_lock(*_trackedRequests->_cancelMutex, *_trackedRequests->_mutex);

/**
* `result` can be have one of three values:
Expand Down Expand Up @@ -75,8 +82,8 @@ void InflightRequests::remove(const Request* const request)
tmpRequest = search->second;
_trackedRequests->_inflight->erase(search);
}
_trackedRequests->_cancelMutex.unlock();
_trackedRequests->_mutex.unlock();
_trackedRequests->_cancelMutex->unlock();
_trackedRequests->_mutex->unlock();
return;
}
} while (true);
Expand All @@ -87,7 +94,8 @@ size_t InflightRequests::dropCanceled()
size_t removed = 0;

{
std::scoped_lock lock{_trackedRequests->_cancelMutex};
std::scoped_lock localLock{_mutex};
std::scoped_lock lock{*_trackedRequests->_cancelMutex};
for (auto it = _trackedRequests->_canceling->begin();
it != _trackedRequests->_canceling->end();) {
auto request = it->second;
Expand All @@ -108,7 +116,8 @@ size_t InflightRequests::getCancelingSize()
dropCanceled();
size_t cancelingSize = 0;
{
std::scoped_lock lock{_trackedRequests->_cancelMutex};
std::scoped_lock localLock{_mutex};
std::scoped_lock lock{*_trackedRequests->_cancelMutex};
cancelingSize = _trackedRequests->_canceling->size();
}

Expand All @@ -120,7 +129,8 @@ size_t InflightRequests::cancelAll()
decltype(_trackedRequests->_inflight) toCancel;
size_t total;
{
std::scoped_lock lock{_trackedRequests->_cancelMutex, _trackedRequests->_mutex};
std::scoped_lock localLock{_mutex};
std::scoped_lock lock{*_trackedRequests->_cancelMutex, *_trackedRequests->_mutex};
total = _trackedRequests->_inflight->size();

// Fast path when no requests have been registered or the map has been
Expand Down Expand Up @@ -155,7 +165,8 @@ size_t InflightRequests::cancelAll()

TrackedRequestsPtr InflightRequests::release()
{
std::scoped_lock lock{_trackedRequests->_cancelMutex, _trackedRequests->_mutex};
std::scoped_lock localLock{_mutex};
std::scoped_lock lock{*_trackedRequests->_cancelMutex, *_trackedRequests->_mutex};

return std::exchange(_trackedRequests, std::make_unique<TrackedRequests>());
}
Expand Down

0 comments on commit b6e4185

Please sign in to comment.