diff --git a/cpp/tests/request.cpp b/cpp/tests/request.cpp index 371453b0..75ca48ec 100644 --- a/cpp/tests/request.cpp +++ b/cpp/tests/request.cpp @@ -363,6 +363,49 @@ TEST_P(RequestTest, TagUserCallback) for (const auto request : requests) ASSERT_THAT(request->getStatus(), UCS_OK); + for (const auto status : requestStatus) + ASSERT_THAT(status, UCS_OK); + + // Assert data correctness + ASSERT_THAT(_recv[0], ContainerEq(_send[0])); +} + +TEST_P(RequestTest, TagUserCallbackDiscardReturn) +{ + allocate(); + + std::vector requestStatus(2, UCS_INPROGRESS); + + auto checkStatus = [&requestStatus](ucs_status_t status, ::ucxx::RequestCallbackUserData data) { + auto idx = *std::static_pointer_cast(data); + requestStatus[idx] = status; + }; + + auto checkCompletion = [&requestStatus, this]() { + std::vector completed(2, 0); + while (std::accumulate(completed.begin(), completed.end(), 0) != 2) { + _progressWorker(); + std::transform( + requestStatus.begin(), requestStatus.end(), completed.begin(), [](ucs_status_t status) { + return status == UCS_INPROGRESS ? 0 : 1; + }); + } + }; + + auto sendIndex = std::make_shared(0u); + auto recvIndex = std::make_shared(1u); + + // Submit and wait for transfers to complete + std::ignore = + _ep->tagSend(_sendPtr[0], _messageSize, ucxx::Tag{0}, false, checkStatus, sendIndex); + std::ignore = _ep->tagRecv( + _recvPtr[0], _messageSize, ucxx::Tag{0}, ucxx::TagMaskFull, false, checkStatus, recvIndex); + checkCompletion(); + + copyResults(); + + for (const auto status : requestStatus) + ASSERT_THAT(status, UCS_OK); // Assert data correctness ASSERT_THAT(_recv[0], ContainerEq(_send[0]));