From 4a962e29127d006774ea0eb057082abf35fe2ab3 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Wed, 14 Aug 2024 02:22:35 -0700 Subject: [PATCH] Test tag user callback with discarded return We were recently inquired by a user whether relying only on the `ucs_status_t` passed to the user callback and discarding the `ucxx::RequestTag` return from `ucxx::Endpoint::tagSend` and `ucxx::Endpoint::tagRecv` is valid. This turns out to be a valid usage, althought I'm unsure whether encouraging this is a good idea, since requests such as `ucxx::Endpoint::amRecv` require that the user retrieve the resulting buffer from the returned `ucxx::RequestAm`, and if this is discarded the buffer is lost. This PR adds a test for the aforementioned use case but makes no changes to documentation to prevent encouraging this pattern until we can decide whether we should support it or not. For requests such as `ucxx::Endpoint::amRecv`, it might be worth studying whether we could pass the resulting buffer and any other attributes associated with it to the callback, in that case we may be able to provide a safe pattern to always use the callback if the user doesn't want to keep a referencce to the returned `ucxx::Request` object. --- cpp/tests/request.cpp | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/cpp/tests/request.cpp b/cpp/tests/request.cpp index 371453b0..75ca48ec 100644 --- a/cpp/tests/request.cpp +++ b/cpp/tests/request.cpp @@ -363,6 +363,49 @@ TEST_P(RequestTest, TagUserCallback) for (const auto request : requests) ASSERT_THAT(request->getStatus(), UCS_OK); + for (const auto status : requestStatus) + ASSERT_THAT(status, UCS_OK); + + // Assert data correctness + ASSERT_THAT(_recv[0], ContainerEq(_send[0])); +} + +TEST_P(RequestTest, TagUserCallbackDiscardReturn) +{ + allocate(); + + std::vector requestStatus(2, UCS_INPROGRESS); + + auto checkStatus = [&requestStatus](ucs_status_t status, ::ucxx::RequestCallbackUserData data) { + auto idx = *std::static_pointer_cast(data); + requestStatus[idx] = status; + }; + + auto checkCompletion = [&requestStatus, this]() { + std::vector completed(2, 0); + while (std::accumulate(completed.begin(), completed.end(), 0) != 2) { + _progressWorker(); + std::transform( + requestStatus.begin(), requestStatus.end(), completed.begin(), [](ucs_status_t status) { + return status == UCS_INPROGRESS ? 0 : 1; + }); + } + }; + + auto sendIndex = std::make_shared(0u); + auto recvIndex = std::make_shared(1u); + + // Submit and wait for transfers to complete + std::ignore = + _ep->tagSend(_sendPtr[0], _messageSize, ucxx::Tag{0}, false, checkStatus, sendIndex); + std::ignore = _ep->tagRecv( + _recvPtr[0], _messageSize, ucxx::Tag{0}, ucxx::TagMaskFull, false, checkStatus, recvIndex); + checkCompletion(); + + copyResults(); + + for (const auto status : requestStatus) + ASSERT_THAT(status, UCS_OK); // Assert data correctness ASSERT_THAT(_recv[0], ContainerEq(_send[0]));