Skip to content

Commit

Permalink
Remove usage of std::unique_ptr for ucxx::TrackedRequests members
Browse files Browse the repository at this point in the history
  • Loading branch information
pentschev committed Jul 4, 2024
1 parent b6e4185 commit b394a72
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 55 deletions.
21 changes: 5 additions & 16 deletions cpp/include/ucxx/inflight_requests.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,18 @@ class Request;
*/
typedef std::map<const Request* const, std::shared_ptr<Request>> InflightRequestsMap;

/**
* @brief Pre-defined type for a pointer to an inflight request map.
*
* A pre-defined type for a pointer to an inflight request map, used as a convenience type.
*/
typedef std::unique_ptr<InflightRequestsMap> InflightRequestsMapPtr;

/**
* @brief A container for the different types of tracked requests.
*
* A container encapsulating the different types of handled tracked requests, currently
* those still valid (inflight), and those scheduled for cancelation (canceling).
*/
typedef struct TrackedRequests {
InflightRequestsMapPtr _inflight{
std::make_unique<InflightRequestsMap>()}; ///< Valid requests awaiting completion.
InflightRequestsMapPtr _canceling{
std::make_unique<InflightRequestsMap>()}; ///< Requests scheduled for cancelation.
std::unique_ptr<std::mutex> _mutex{
std::make_unique<std::mutex>()}; ///< Mutex to control access to inflight requests container
std::unique_ptr<std::mutex> _cancelMutex{
std::make_unique<std::mutex>()}; ///< Mutex to allow cancelation and prevent removing requests
///< simultaneously
InflightRequestsMap _inflight{}; ///< Valid requests awaiting completion.
InflightRequestsMap _canceling{}; ///< Requests scheduled for cancelation.
std::mutex _mutex{}; ///< Mutex to control access to inflight requests container
std::mutex
_cancelMutex{}; ///< Mutex to allow cancelation and prevent removing requests simultaneously
} TrackedRequests;

/**
Expand Down
70 changes: 32 additions & 38 deletions cpp/src/inflight_requests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ InflightRequests::~InflightRequests() { cancelAll(); }
size_t InflightRequests::size()
{
std::scoped_lock localLock{_mutex};
std::lock_guard<std::mutex> lock(*_trackedRequests->_mutex);
return _trackedRequests->_inflight->size();
std::lock_guard<std::mutex> lock(_trackedRequests->_mutex);
return _trackedRequests->_inflight.size();
}

void InflightRequests::insert(std::shared_ptr<Request> request)
{
std::scoped_lock localLock{_mutex};
std::lock_guard<std::mutex> lock(*_trackedRequests->_mutex);
std::lock_guard<std::mutex> lock(_trackedRequests->_mutex);

_trackedRequests->_inflight->insert({request.get(), request});
_trackedRequests->_inflight.insert({request.get(), request});
}

void InflightRequests::merge(TrackedRequestsPtr trackedRequests)
Expand All @@ -33,27 +33,21 @@ void InflightRequests::merge(TrackedRequestsPtr trackedRequests)
if (trackedRequests == nullptr) return;

std::scoped_lock localLock{_mutex};
std::scoped_lock lock{*_trackedRequests->_cancelMutex,
*_trackedRequests->_mutex,
*trackedRequests->_cancelMutex,
*trackedRequests->_mutex};

if (trackedRequests->_inflight != nullptr)
_trackedRequests->_inflight->merge(*(trackedRequests->_inflight));
else
ucxx_error("Invalid _inflight object during merge");
if (trackedRequests->_canceling != nullptr)
_trackedRequests->_canceling->merge(*(trackedRequests->_canceling));
else
ucxx_error("Invalid _canceling object during merge");
std::scoped_lock lock{_trackedRequests->_cancelMutex,
_trackedRequests->_mutex,
trackedRequests->_cancelMutex,
trackedRequests->_mutex};

_trackedRequests->_inflight.merge(trackedRequests->_inflight);
_trackedRequests->_canceling.merge(trackedRequests->_canceling);
}
}

void InflightRequests::remove(const Request* const request)
{
do {
std::scoped_lock localLock{_mutex};
int result = std::try_lock(*_trackedRequests->_cancelMutex, *_trackedRequests->_mutex);
int result = std::try_lock(_trackedRequests->_cancelMutex, _trackedRequests->_mutex);

/**
* `result` can be have one of three values:
Expand All @@ -68,9 +62,9 @@ void InflightRequests::remove(const Request* const request)
if (result == 0) {
return;
} else if (result == -1) {
auto search = _trackedRequests->_inflight->find(request);
auto search = _trackedRequests->_inflight.find(request);
decltype(search->second) tmpRequest;
if (search != _trackedRequests->_inflight->end()) {
if (search != _trackedRequests->_inflight.end()) {
/**
* If this is the last request to hold `std::shared_ptr<ucxx::Endpoint>` erasing it
* may cause the `ucxx::Endpoint`s destructor and subsequently the `closeBlocking()`
Expand All @@ -80,10 +74,10 @@ void InflightRequests::remove(const Request* const request)
* destroy the object upon this method's return.
*/
tmpRequest = search->second;
_trackedRequests->_inflight->erase(search);
_trackedRequests->_inflight.erase(search);
}
_trackedRequests->_cancelMutex->unlock();
_trackedRequests->_mutex->unlock();
_trackedRequests->_cancelMutex.unlock();
_trackedRequests->_mutex.unlock();
return;
}
} while (true);
Expand All @@ -95,12 +89,12 @@ size_t InflightRequests::dropCanceled()

{
std::scoped_lock localLock{_mutex};
std::scoped_lock lock{*_trackedRequests->_cancelMutex};
for (auto it = _trackedRequests->_canceling->begin();
it != _trackedRequests->_canceling->end();) {
std::scoped_lock lock{_trackedRequests->_cancelMutex};
for (auto it = _trackedRequests->_canceling.begin();
it != _trackedRequests->_canceling.end();) {
auto request = it->second;
if (request != nullptr && request->getStatus() != UCS_INPROGRESS) {
it = _trackedRequests->_canceling->erase(it);
it = _trackedRequests->_canceling.erase(it);
++removed;
} else {
++it;
Expand All @@ -117,8 +111,8 @@ size_t InflightRequests::getCancelingSize()
size_t cancelingSize = 0;
{
std::scoped_lock localLock{_mutex};
std::scoped_lock lock{*_trackedRequests->_cancelMutex};
cancelingSize = _trackedRequests->_canceling->size();
std::scoped_lock lock{_trackedRequests->_cancelMutex};
cancelingSize = _trackedRequests->_canceling.size();
}

return cancelingSize;
Expand All @@ -130,30 +124,30 @@ size_t InflightRequests::cancelAll()
size_t total;
{
std::scoped_lock localLock{_mutex};
std::scoped_lock lock{*_trackedRequests->_cancelMutex, *_trackedRequests->_mutex};
total = _trackedRequests->_inflight->size();
std::scoped_lock lock{_trackedRequests->_cancelMutex, _trackedRequests->_mutex};
total = _trackedRequests->_inflight.size();

// Fast path when no requests have been registered or the map has been
// previously released.
if (total == 0) return 0;

toCancel = std::exchange(_trackedRequests->_inflight, std::make_unique<InflightRequestsMap>());
toCancel = std::exchange(_trackedRequests->_inflight, InflightRequestsMap());

ucxx_debug("ucxx::InflightRequests::%s, canceling %lu requests", __func__, total);

for (auto& r : *toCancel) {
for (auto& r : toCancel) {
auto request = r.second;
if (request != nullptr) { request->cancel(); }
}

_trackedRequests->_canceling->merge(*toCancel);
_trackedRequests->_canceling.merge(toCancel);

// dropCanceled();
for (auto it = _trackedRequests->_canceling->begin();
it != _trackedRequests->_canceling->end();) {
for (auto it = _trackedRequests->_canceling.begin();
it != _trackedRequests->_canceling.end();) {
auto request = it->second;
if (request != nullptr && request->getStatus() != UCS_INPROGRESS) {
it = _trackedRequests->_canceling->erase(it);
it = _trackedRequests->_canceling.erase(it);
} else {
++it;
}
Expand All @@ -166,7 +160,7 @@ size_t InflightRequests::cancelAll()
TrackedRequestsPtr InflightRequests::release()
{
std::scoped_lock localLock{_mutex};
std::scoped_lock lock{*_trackedRequests->_cancelMutex, *_trackedRequests->_mutex};
std::scoped_lock lock{_trackedRequests->_cancelMutex, _trackedRequests->_mutex};

return std::exchange(_trackedRequests, std::make_unique<TrackedRequests>());
}
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ void Worker::scheduleRequestCancel(TrackedRequestsPtr trackedRequests)
__func__,
this,
_handle,
trackedRequests->_inflight->size() + trackedRequests->_canceling->size());
trackedRequests->_inflight.size() + trackedRequests->_canceling.size());
_inflightRequestsToCancel->merge(std::move(trackedRequests));
}
}
Expand Down

0 comments on commit b394a72

Please sign in to comment.