Skip to content

Commit

Permalink
Fix segmentation fault in TrackedRequests
Browse files Browse the repository at this point in the history
Move `TrackedRequests` mutexes to the class itself, allowing locking the
objects even after releasing ownership. With this change, it's now
possible to lock the object within `InflightRewquests::merge` which was
a source of segmentation faults.
  • Loading branch information
pentschev committed Jul 2, 2024
1 parent 2aa56d8 commit c12808c
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 13 deletions.
6 changes: 3 additions & 3 deletions cpp/include/ucxx/inflight_requests.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ typedef struct TrackedRequests {
std::make_unique<InflightRequestsMap>()}; ///< Valid requests awaiting completion.
InflightRequestsMapPtr _canceling{
std::make_unique<InflightRequestsMap>()}; ///< Requests scheduled for cancelation.
std::mutex _mutex{}; ///< Mutex to control access to inflight requests container
std::mutex
_cancelMutex{}; ///< Mutex to allow cancelation and prevent removing requests simultaneously
} TrackedRequests;

/**
Expand All @@ -61,9 +64,6 @@ class InflightRequests {
std::make_unique<TrackedRequests>()}; ///< Container storing pointers to all inflight
///< and in cancelation process requests known to
///< the owner of this object
std::mutex _mutex{}; ///< Mutex to control access to inflight requests container
std::mutex
_cancelMutex{}; ///< Mutex to allow cancelation and prevent removing requests simultaneously

/**
* @brief Drop references to requests that completed cancelation.
Expand Down
23 changes: 13 additions & 10 deletions cpp/src/inflight_requests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,24 @@ InflightRequests::~InflightRequests() { cancelAll(); }

size_t InflightRequests::size()
{
std::lock_guard<std::mutex> lock(_mutex);
std::lock_guard<std::mutex> lock(_trackedRequests->_mutex);
return _trackedRequests->_inflight->size();
}

void InflightRequests::insert(std::shared_ptr<Request> request)
{
std::lock_guard<std::mutex> lock(_mutex);
std::lock_guard<std::mutex> lock(_trackedRequests->_mutex);

_trackedRequests->_inflight->insert({request.get(), request});
}

void InflightRequests::merge(TrackedRequestsPtr trackedRequests)
{
{
std::scoped_lock lock{_cancelMutex, _mutex};
std::scoped_lock lock{_trackedRequests->_cancelMutex,
_trackedRequests->_mutex,
trackedRequests->_cancelMutex,
trackedRequests->_mutex};
if (trackedRequests->_inflight != nullptr)
_trackedRequests->_inflight->merge(*(trackedRequests->_inflight));
else
Expand All @@ -43,7 +46,7 @@ void InflightRequests::merge(TrackedRequestsPtr trackedRequests)
void InflightRequests::remove(const Request* const request)
{
do {
int result = std::try_lock(_cancelMutex, _mutex);
int result = std::try_lock(_trackedRequests->_cancelMutex, _trackedRequests->_mutex);

/**
* `result` can be have one of three values:
Expand Down Expand Up @@ -72,8 +75,8 @@ void InflightRequests::remove(const Request* const request)
tmpRequest = search->second;
_trackedRequests->_inflight->erase(search);
}
_cancelMutex.unlock();
_mutex.unlock();
_trackedRequests->_cancelMutex.unlock();
_trackedRequests->_mutex.unlock();
return;
}
} while (true);
Expand All @@ -84,7 +87,7 @@ size_t InflightRequests::dropCanceled()
size_t removed = 0;

{
std::scoped_lock lock{_cancelMutex};
std::scoped_lock lock{_trackedRequests->_cancelMutex};
for (auto it = _trackedRequests->_canceling->begin();
it != _trackedRequests->_canceling->end();) {
auto request = it->second;
Expand All @@ -105,7 +108,7 @@ size_t InflightRequests::getCancelingSize()
dropCanceled();
size_t cancelingSize = 0;
{
std::scoped_lock lock{_cancelMutex};
std::scoped_lock lock{_trackedRequests->_cancelMutex};
cancelingSize = _trackedRequests->_canceling->size();
}

Expand All @@ -117,7 +120,7 @@ size_t InflightRequests::cancelAll()
decltype(_trackedRequests->_inflight) toCancel;
size_t total;
{
std::scoped_lock lock{_cancelMutex, _mutex};
std::scoped_lock lock{_trackedRequests->_cancelMutex, _trackedRequests->_mutex};
total = _trackedRequests->_inflight->size();

// Fast path when no requests have been registered or the map has been
Expand Down Expand Up @@ -152,7 +155,7 @@ size_t InflightRequests::cancelAll()

TrackedRequestsPtr InflightRequests::release()
{
std::scoped_lock lock{_cancelMutex, _mutex};
std::scoped_lock lock{_trackedRequests->_cancelMutex, _trackedRequests->_mutex};

return std::exchange(_trackedRequests, std::make_unique<TrackedRequests>());
}
Expand Down

0 comments on commit c12808c

Please sign in to comment.