From 83c5dbf64213412b82daeca569c5714d8f56b2d9 Mon Sep 17 00:00:00 2001 From: qicosmos Date: Mon, 25 Mar 2024 11:27:54 +0800 Subject: [PATCH] support lazy callback (#647) --- include/ylt/coro_io/coro_io.hpp | 63 ++++++++++++++++++- .../thirdparty/async_simple/coro/Collect.h | 4 +- src/coro_io/tests/test_coro_channel.cpp | 36 +++++++++++ 3 files changed, 99 insertions(+), 4 deletions(-) diff --git a/include/ylt/coro_io/coro_io.hpp b/include/ylt/coro_io/coro_io.hpp index 81cd605a4..40cb4c8d4 100644 --- a/include/ylt/coro_io/coro_io.hpp +++ b/include/ylt/coro_io/coro_io.hpp @@ -42,6 +42,10 @@ namespace coro_io { +template +constexpr inline bool is_lazy_v = + util::is_specialization_v, async_simple::coro::Lazy>; + template class callback_awaitor_base { private: @@ -395,9 +399,64 @@ async_simple::coro::Lazy +inline decltype(auto) select_impl(T &pair) { + using Func = std::tuple_element_t<1, std::remove_cvref_t>; + using ValueType = + typename std::tuple_element_t<0, std::remove_cvref_t>::ValueType; + using return_type = std::invoke_result_t>; + + auto &callback = std::get<1>(pair); + if constexpr (coro_io::is_lazy_v) { + auto executor = std::get<0>(pair).getExecutor(); + return std::make_pair( + std::move(std::get<0>(pair)), + [executor, callback = std::move(callback)](auto &&val) { + if (executor) { + callback(std::move(val)).via(executor).start([](auto &&) { + }); + } + else { + callback(std::move(val)).start([](auto &&) { + }); + } + }); + } + else { + return pair; + } +} + template -auto select(T &&...args) { - return async_simple::coro::collectAny(std::forward(args)...); +inline auto select(T &&...args) { + return async_simple::coro::collectAny(select_impl(args)...); +} + +template +inline auto select(std::vector vec, Callback callback) { + if constexpr (coro_io::is_lazy_v) { + std::vector executors; + for (auto &lazy : vec) { + executors.push_back(lazy.getExecutor()); + } + + return async_simple::coro::collectAny( + std::move(vec), + [executors, callback = std::move(callback)](size_t index, auto &&val) { + auto executor = executors[index]; + if (executor) { + callback(index, std::move(val)).via(executor).start([](auto &&) { + }); + } + else { + callback(index, std::move(val)).start([](auto &&) { + }); + } + }); + } + else { + return async_simple::coro::collectAny(std::move(vec), std::move(callback)); + } } template diff --git a/include/ylt/thirdparty/async_simple/coro/Collect.h b/include/ylt/thirdparty/async_simple/coro/Collect.h index 218cd7b43..16f239672 100644 --- a/include/ylt/thirdparty/async_simple/coro/Collect.h +++ b/include/ylt/thirdparty/async_simple/coro/Collect.h @@ -152,7 +152,7 @@ struct CollectAnyAwaiter { auto count = e->downCount(); if (count == size + 1) { r->_idx = i; - (*callback)(i, std::move(result)); + (void)(*callback)(i, std::move(result)); c.resume(); } }); @@ -222,7 +222,7 @@ struct CollectAnyVariadicPairAwaiter { callback](auto&& res) mutable { auto count = event->downCount(); if (count == std::tuple_size() + 1) { - callback(std::move(res)); + (void)callback(std::move(res)); *result = I; continuation.resume(); } diff --git a/src/coro_io/tests/test_coro_channel.cpp b/src/coro_io/tests/test_coro_channel.cpp index aceb43c8e..bdd343ebf 100644 --- a/src/coro_io/tests/test_coro_channel.cpp +++ b/src/coro_io/tests/test_coro_channel.cpp @@ -111,7 +111,43 @@ async_simple::coro::Lazy test_select_channel() { } } +void callback_lazy() { + using namespace async_simple::coro; + auto test0 = []() mutable -> Lazy { + co_return 41; + }; + + auto test1 = []() mutable -> Lazy { + co_return 42; + }; + + auto collectAnyLazy = [](auto&&... args) mutable -> Lazy { + co_return co_await collectAny(std::move(args)...); + }; + + syncAwait( + collectAnyLazy(std::pair{test1(), [&](auto&& val) mutable -> Lazy { + CHECK(val.value() == 42); + int r = co_await test0(); + int result = r + val.value(); + CHECK(result == 83); + }})); + + std::vector> input; + input.push_back(test1()); + + auto index = syncAwait(collectAnyLazy( + std::move(input), [&test0](size_t index, auto val) mutable -> Lazy { + CHECK(val.value() == 42); + int r = co_await test0(); + int result = r + val.value(); + CHECK(result == 83); + })); + CHECK(index == 0); +} + TEST_CASE("test channel send recieve, test select channel and coroutine") { async_simple::coro::syncAwait(test_coro_channel()); async_simple::coro::syncAwait(test_select_channel()); + callback_lazy(); } \ No newline at end of file