diff --git a/include/cinatra/ylt/coro_io/coro_io.hpp b/include/cinatra/ylt/coro_io/coro_io.hpp index e25f0a17..ac2ca126 100644 --- a/include/cinatra/ylt/coro_io/coro_io.hpp +++ b/include/cinatra/ylt/coro_io/coro_io.hpp @@ -4,6 +4,8 @@ #include #include +#include "async_simple/coro/Collect.h" + #if defined(YLT_ENABLE_SSL) || defined(CINATRA_ENABLE_SSL) #include #endif @@ -333,8 +335,24 @@ post(Func func, co_return co_await awaitor.await_resume(helper); } +template +struct coro_channel + : public asio::experimental::channel { + using return_type = R; + using ValueType = std::pair; + using asio::experimental::channel::channel; +}; + +template +inline coro_channel create_channel( + size_t capacity, + asio::io_context::executor_type executor = + coro_io::get_global_block_executor()->get_asio_executor()) { + return coro_channel(executor, capacity); +} + template -async_simple::coro::Lazy async_send( +inline async_simple::coro::Lazy async_send( asio::experimental::channel &channel, T val) { callback_awaitor awaitor; co_return co_await awaitor.await_resume( @@ -345,10 +363,12 @@ async_simple::coro::Lazy async_send( }); } -template -async_simple::coro::Lazy> async_receive( - asio::experimental::channel &channel) { - callback_awaitor> awaitor; +template +async_simple::coro::Lazy> inline async_receive(Channel &channel) { + callback_awaitor> + awaitor; co_return co_await awaitor.await_resume([&](auto handler) { channel.async_receive([handler](auto ec, auto val) { handler.set_value_then_resume(std::make_pair(ec, std::move(val))); @@ -356,6 +376,11 @@ async_simple::coro::Lazy> async_receive( }); } +template +auto select(T &&...args) { + return async_simple::coro::collectAny(std::forward(args)...); +} + template std::pair read_some(Socket &sock, AsioBuffer &&buffer) { diff --git a/tests/test_cinatra.cpp b/tests/test_cinatra.cpp index 47d9f584..a5689685 100644 --- a/tests/test_cinatra.cpp +++ b/tests/test_cinatra.cpp @@ -188,8 +188,7 @@ TEST_CASE("test cinatra::string SSO to no SSO") { } TEST_CASE("test coro channel") { - auto ctx = coro_io::get_global_block_executor()->get_asio_executor(); - asio::experimental::channel ch(ctx, 10000); + auto ch = coro_io::create_channel(1000); auto ec = async_simple::coro::syncAwait(coro_io::async_send(ch, 41)); CHECK(!ec); ec = async_simple::coro::syncAwait(coro_io::async_send(ch, 42)); @@ -198,16 +197,126 @@ TEST_CASE("test coro channel") { std::error_code err; int val; std::tie(err, val) = - async_simple::coro::syncAwait(coro_io::async_receive(ch)); + async_simple::coro::syncAwait(coro_io::async_receive(ch)); CHECK(!err); CHECK(val == 41); std::tie(err, val) = - async_simple::coro::syncAwait(coro_io::async_receive(ch)); + async_simple::coro::syncAwait(coro_io::async_receive(ch)); CHECK(!err); CHECK(val == 42); } +async_simple::coro::Lazy test_select_channel() { + using namespace coro_io; + using namespace async_simple; + using namespace async_simple::coro; + + auto ch1 = coro_io::create_channel(1000); + auto ch2 = coro_io::create_channel(1000); + + co_await async_send(ch1, 41); + co_await async_send(ch2, 42); + + std::array arr{41, 42}; + int val; + + size_t index = + co_await select(std::pair{async_receive(ch1), + [&val](auto value) { + auto [ec, r] = value.value(); + val = r; + }}, + std::pair{async_receive(ch2), [&val](auto value) { + auto [ec, r] = value.value(); + val = r; + }}); + + CHECK(val == arr[index]); + + co_await async_send(ch1, 41); + co_await async_send(ch2, 42); + + std::vector>> vec; + vec.push_back(async_receive(ch1)); + vec.push_back(async_receive(ch2)); + + index = co_await select(std::move(vec), [&](size_t i, auto result) { + val = result.value().second; + }); + CHECK(val == arr[index]); + + period_timer timer1(coro_io::get_global_executor()); + timer1.expires_after(100ms); + period_timer timer2(coro_io::get_global_executor()); + timer2.expires_after(200ms); + + int val1; + index = co_await select(std::pair{timer1.async_await(), + [&](auto val) { + CHECK(val.value()); + val1 = 0; + }}, + std::pair{timer2.async_await(), [&](auto val) { + CHECK(val.value()); + val1 = 0; + }}); + CHECK(index == val1); + + int val2; + index = co_await select(std::pair{coro_io::post([] { + }), + [&](auto) { + std::cout << "post1\n"; + val2 = 0; + }}, + std::pair{coro_io::post([] { + }), + [&](auto) { + std::cout << "post2\n"; + val2 = 1; + }}); + CHECK(index == val2); + + co_await async_send(ch1, 43); + auto lazy = coro_io::post([] { + }); + + int val3 = -1; + index = co_await select(std::pair{async_receive(ch1), + [&](auto result) { + val3 = result.value().second; + }}, + std::pair{std::move(lazy), [&](auto) { + val3 = 0; + }}); + + if (index == 0) { + CHECK(val3 == 43); + } + else if (index == 1) { + CHECK(val3 == 0); + } +} + +TEST_CASE("test select coro channel") { + using namespace coro_io; + async_simple::coro::syncAwait(test_select_channel()); + + auto ch = coro_io::create_channel(1000); + + async_simple::coro::syncAwait(coro_io::async_send(ch, 41)); + async_simple::coro::syncAwait(coro_io::async_send(ch, 42)); + + std::error_code ec; + int val; + std::tie(ec, val) = async_simple::coro::syncAwait(coro_io::async_receive(ch)); + CHECK(val == 41); + + std::tie(ec, val) = async_simple::coro::syncAwait(coro_io::async_receive(ch)); + CHECK(val == 42); +} + async_simple::coro::Lazy test_collect_all() { asio::io_context ioc; std::thread thd([&] {