diff --git a/include/cinatra/coro_http_server.hpp b/include/cinatra/coro_http_server.hpp index 162c2af6..c0ab57f9 100644 --- a/include/cinatra/coro_http_server.hpp +++ b/include/cinatra/coro_http_server.hpp @@ -220,6 +220,61 @@ class coro_http_server { } } + template + void set_websocket_proxy_handler(std::string url_path, + std::vector hosts, + coro_io::load_blance_algorithm type = + coro_io::load_blance_algorithm::random, + std::vector weights = {}, + Aspects &&...aspects) { + if (hosts.empty()) { + throw std::invalid_argument("not config hosts yet!"); + } + + auto channel = std::make_shared>( + coro_io::channel::create(hosts, {.lba = type}, + weights)); + + set_http_handler( + url_path, + [channel](coro_http_request &req, + coro_http_response &resp) -> async_simple::coro::Lazy { + websocket_result result{}; + while (true) { + result = co_await req.get_conn()->read_websocket(); + if (result.ec) { + break; + } + + if (result.type == ws_frame_type::WS_CLOSE_FRAME) { + CINATRA_LOG_INFO << "close frame"; + break; + } + + co_await channel->send_request( + [&req, result]( + coro_http_client &client, + std::string_view host) -> async_simple::coro::Lazy { + auto r = + co_await client.write_websocket(std::string(result.data)); + if (r.net_err) { + co_return; + } + auto data = co_await client.read_websocket(); + if (data.net_err) { + co_return; + } + auto ec = co_await req.get_conn()->write_websocket( + std::string(result.data)); + if (ec) { + co_return; + } + }); + } + }, + std::forward(aspects)...); + } + void set_max_size_of_cache_files(size_t max_size = 3 * 1024 * 1024) { std::error_code ec; for (const auto &file : diff --git a/tests/test_coro_http_server.cpp b/tests/test_coro_http_server.cpp index f5296868..2e8a8b8c 100644 --- a/tests/test_coro_http_server.cpp +++ b/tests/test_coro_http_server.cpp @@ -1486,4 +1486,44 @@ TEST_CASE("test reverse proxy") { req_content_type::text); std::cout << resp_random.resp_body << "\n"; CHECK(!resp_random.resp_body.empty()); -} \ No newline at end of file +} + +TEST_CASE("test reverse proxy websocket") { + coro_http_server server(1, 9001); + server.set_http_handler( + "/ws_echo", + [](coro_http_request &req, + coro_http_response &resp) -> async_simple::coro::Lazy { + CHECK(req.get_content_type() == content_type::websocket); + websocket_result result{}; + while (true) { + result = co_await req.get_conn()->read_websocket(); + if (result.ec) { + break; + } + + auto ec = co_await req.get_conn()->write_websocket(result.data); + if (ec) { + break; + } + } + }); + server.async_start(); + + coro_http_server proxy_server(1, 9002); + proxy_server.set_websocket_proxy_handler("/ws_echo", + {"ws://127.0.0.1:9001/ws_echo"}); + proxy_server.async_start(); + std::this_thread::sleep_for(200ms); + + coro_http_client client{}; + auto r = async_simple::coro::syncAwait( + client.connect("ws://127.0.0.1:9002/ws_echo")); + CHECK(!r.net_err); + for (int i = 0; i < 10; i++) { + async_simple::coro::syncAwait(client.write_websocket("test websocket")); + auto data = async_simple::coro::syncAwait(client.read_websocket()); + std::cout << data.resp_body << "\n"; + CHECK(data.resp_body == "test websocket"); + } +}