diff --git a/example/main.cpp b/example/main.cpp index 92fc4c4b..ca13fc09 100644 --- a/example/main.cpp +++ b/example/main.cpp @@ -205,8 +205,7 @@ async_simple::coro::Lazy use_websocket() { assert(!result.net_err); auto data = co_await client.read_websocket(); assert(data.resp_body == "hello websocket"); - result = - co_await client.write_websocket("test again", /*need_mask = */ false); + result = co_await client.write_websocket("test again"); assert(!result.net_err); data = co_await client.read_websocket(); assert(data.resp_body == "test again"); diff --git a/include/cinatra/coro_http_client.hpp b/include/cinatra/coro_http_client.hpp index e68dc1f9..b23bdc4a 100644 --- a/include/cinatra/coro_http_client.hpp +++ b/include/cinatra/coro_http_client.hpp @@ -362,24 +362,36 @@ class coro_http_client : public std::enable_shared_from_this { } async_simple::coro::Lazy write_websocket( - const char *data, bool need_mask = true, opcode op = opcode::text) { + const char *data, opcode op = opcode::text) { std::string str(data); - co_return co_await write_websocket(std::span(str), need_mask, op); + co_return co_await write_websocket(str, op); } async_simple::coro::Lazy write_websocket( - std::string &data, bool need_mask = true, opcode op = opcode::text) { - co_return co_await write_websocket(std::span(data), need_mask, op); + const char *data, size_t size, opcode op = opcode::text) { + std::string str(data, size); + co_return co_await write_websocket(str, op); } async_simple::coro::Lazy write_websocket( - std::string &&data, bool need_mask = true, opcode op = opcode::text) { - co_return co_await write_websocket(std::span(data), need_mask, op); + std::string_view data, opcode op = opcode::text) { + std::string str(data); + co_return co_await write_websocket(str, op); + } + + async_simple::coro::Lazy write_websocket( + std::string &data, opcode op = opcode::text) { + co_return co_await write_websocket(std::span(data), op); + } + + async_simple::coro::Lazy write_websocket( + std::string &&data, opcode op = opcode::text) { + co_return co_await write_websocket(std::span(data), op); } template async_simple::coro::Lazy write_websocket( - Source source, bool need_mask = true, opcode op = opcode::text) { + Source source, opcode op = opcode::text) { resp_data data{}; websocket ws{}; @@ -399,7 +411,7 @@ class coro_http_client : public std::enable_shared_from_this { if (cinatra::gzip_codec::deflate( std::string(source.begin(), source.end()), dest_buf)) { std::span msg(dest_buf.data(), dest_buf.size()); - auto header = ws.encode_frame(msg, op, need_mask, true, true); + auto header = ws.encode_frame(msg, op, true, true); std::vector buffers; buffers.push_back(asio::buffer(header)); buffers.push_back(asio::buffer(dest_buf)); @@ -418,8 +430,23 @@ class coro_http_client : public std::enable_shared_from_this { } } else { -#endif - std::string encode_header = ws.encode_frame(source, op, need_mask); + std::string encode_header = ws.encode_frame(source, op, true); + std::vector buffers{ + asio::buffer(encode_header.data(), encode_header.size()), + asio::buffer(source.data(), source.size())}; + + auto [ec, _] = co_await async_write(buffers); + if (ec) { + data.net_err = ec; + data.status = 404; + } + } + else { + while (true) { + auto result = co_await source(); + + std::span msg(result.buf.data(), result.buf.size()); + std::string encode_header = ws.encode_frame(msg, op, result.eof); std::vector buffers{ asio::buffer(encode_header.data(), encode_header.size()), asio::buffer(source.data(), source.size())}; @@ -443,7 +470,7 @@ class coro_http_client : public std::enable_shared_from_this { if (cinatra::gzip_codec::deflate(std::string(result), dest_buf)) { std::span msg(dest_buf.data(), dest_buf.size()); std::string header = - ws.encode_frame(msg, op, need_mask, result.eof, true); + ws.encode_frame(msg, op, result.eof, true); std::vector buffers; buffers.push_back(asio::buffer(header)); buffers.push_back(asio::buffer(dest_buf)); @@ -466,7 +493,7 @@ class coro_http_client : public std::enable_shared_from_this { std::span msg(result.buf.data(), result.buf.size()); std::string encode_header = - ws.encode_frame(msg, op, need_mask, result.eof); + ws.encode_frame(msg, op, result.eof); std::vector buffers{ asio::buffer(encode_header.data(), encode_header.size()), asio::buffer(msg.data(), msg.size())}; @@ -492,7 +519,7 @@ class coro_http_client : public std::enable_shared_from_this { async_simple::coro::Lazy write_websocket_close( std::string msg = "") { - co_return co_await write_websocket(std::move(msg), false, opcode::close); + co_return co_await write_websocket(std::move(msg), opcode::close); } #ifdef BENCHMARK_TEST @@ -1940,7 +1967,7 @@ class coro_http_client : public std::enable_shared_from_this { auto close_str = ws.format_close_payload(close_code::normal, reason.data(), reason.size()); auto span = std::span(close_str); - std::string encode_header = ws.encode_frame(span, opcode::close, false); + std::string encode_header = ws.encode_frame(span, opcode::close, true); std::vector buffers{asio::buffer(encode_header), asio::buffer(reason)}; diff --git a/include/cinatra/coro_http_server.hpp b/include/cinatra/coro_http_server.hpp index 162c2af6..c0ab57f9 100644 --- a/include/cinatra/coro_http_server.hpp +++ b/include/cinatra/coro_http_server.hpp @@ -220,6 +220,61 @@ class coro_http_server { } } + template + void set_websocket_proxy_handler(std::string url_path, + std::vector hosts, + coro_io::load_blance_algorithm type = + coro_io::load_blance_algorithm::random, + std::vector weights = {}, + Aspects &&...aspects) { + if (hosts.empty()) { + throw std::invalid_argument("not config hosts yet!"); + } + + auto channel = std::make_shared>( + coro_io::channel::create(hosts, {.lba = type}, + weights)); + + set_http_handler( + url_path, + [channel](coro_http_request &req, + coro_http_response &resp) -> async_simple::coro::Lazy { + websocket_result result{}; + while (true) { + result = co_await req.get_conn()->read_websocket(); + if (result.ec) { + break; + } + + if (result.type == ws_frame_type::WS_CLOSE_FRAME) { + CINATRA_LOG_INFO << "close frame"; + break; + } + + co_await channel->send_request( + [&req, result]( + coro_http_client &client, + std::string_view host) -> async_simple::coro::Lazy { + auto r = + co_await client.write_websocket(std::string(result.data)); + if (r.net_err) { + co_return; + } + auto data = co_await client.read_websocket(); + if (data.net_err) { + co_return; + } + auto ec = co_await req.get_conn()->write_websocket( + std::string(result.data)); + if (ec) { + co_return; + } + }); + } + }, + std::forward(aspects)...); + } + void set_max_size_of_cache_files(size_t max_size = 3 * 1024 * 1024) { std::error_code ec; for (const auto &file : diff --git a/include/cinatra/websocket.hpp b/include/cinatra/websocket.hpp index 1b7ee2d2..99b7d1be 100644 --- a/include/cinatra/websocket.hpp +++ b/include/cinatra/websocket.hpp @@ -127,8 +127,8 @@ class websocket { return {msg_header_, header_length}; } - std::string encode_frame(std::span &data, opcode op, bool need_mask, - bool eof = true, bool need_compression = false) { + + std::string encode_frame(std::span &data, opcode op, bool eof, bool need_compression = false) { std::string header; /// Base header. frame_header hdr{}; @@ -177,11 +177,9 @@ class websocket { /// The mask is a 32-bit value. uint8_t mask[4] = {}; - if (need_mask) { - header[1] |= 0x80; - uint32_t random = (uint32_t)rand(); - memcpy(mask, &random, 4); - } + header[1] |= 0x80; + uint32_t random = (uint32_t)rand(); + memcpy(mask, &random, 4); size_t size = header.size(); header.resize(size + 4); diff --git a/tests/test_cinatra_websocket.cpp b/tests/test_cinatra_websocket.cpp index 26b58169..a1a61115 100644 --- a/tests/test_cinatra_websocket.cpp +++ b/tests/test_cinatra_websocket.cpp @@ -64,8 +64,7 @@ async_simple::coro::Lazy test_websocket(coro_http_client &client) { auto result = co_await client.write_websocket("hello websocket"); auto data = co_await client.read_websocket(); CHECK(data.resp_body == "hello websocket"); - co_await client.write_websocket("test again", /*need_mask = */ - false); + co_await client.write_websocket("test again"); data = co_await client.read_websocket(); CHECK(data.resp_body == "test again"); co_await client.write_websocket_close("ws close"); @@ -243,7 +242,8 @@ async_simple::coro::Lazy test_websocket() { co_return; } - co_await client.write_websocket("test2fdsaf", true, opcode::binary); + co_await client.write_websocket(std::string_view("test2fdsaf"), + opcode::binary); auto data = co_await client.read_websocket(); CHECK(data.resp_body == "test2fdsaf"); diff --git a/tests/test_coro_http_server.cpp b/tests/test_coro_http_server.cpp index 91dbb481..2e8a8b8c 100644 --- a/tests/test_coro_http_server.cpp +++ b/tests/test_coro_http_server.cpp @@ -802,6 +802,8 @@ TEST_CASE("test websocket with chunked") { break; } + std::cout << result.data.size() << "\n"; + if (result.data.size() < ws_chunk_size) { CHECK(result.data.size() == 24); CHECK(result.eof); @@ -841,7 +843,7 @@ TEST_CASE("test websocket with chunked") { }; async_simple::coro::syncAwait( - client.write_websocket(std::move(source_fn), true, opcode::binary)); + client.write_websocket(std::move(source_fn), opcode::binary)); auto data = async_simple::coro::syncAwait(client.read_websocket()); if (data.net_err) { @@ -912,16 +914,17 @@ TEST_CASE("test websocket") { auto lazy = []() -> async_simple::coro::Lazy { coro_http_client client{}; co_await client.connect("ws://127.0.0.1:9001/ws_echo"); - co_await client.write_websocket("test2fdsaf", true, opcode::binary); + co_await client.write_websocket(std::string_view("test2fdsaf"), + opcode::binary); auto data = co_await client.read_websocket(); CHECK(data.resp_body == "test2fdsaf"); co_await client.write_websocket("test_ws"); data = co_await client.read_websocket(); CHECK(data.resp_body == "test_ws"); - co_await client.write_websocket("PING", false, opcode::ping); + co_await client.write_websocket("PING", opcode::ping); data = co_await client.read_websocket(); CHECK(data.resp_body == "pong"); - co_await client.write_websocket("PONG", false, opcode::pong); + co_await client.write_websocket("PONG", opcode::pong); data = co_await client.read_websocket(); CHECK(data.resp_body == "ping"); co_await client.write_websocket_close("normal close"); @@ -1031,7 +1034,7 @@ TEST_CASE("test websocket binary data") { std::string short_str(127, 'A'); async_simple::coro::syncAwait( - client1->write_websocket(std::move(short_str), true, opcode::binary)); + client1->write_websocket(std::move(short_str), opcode::binary)); auto client2 = std::make_shared(); async_simple::coro::syncAwait( @@ -1039,7 +1042,7 @@ TEST_CASE("test websocket binary data") { std::string medium_str(65535, 'A'); async_simple::coro::syncAwait( - client2->write_websocket(std::move(medium_str), true, opcode::binary)); + client2->write_websocket(std::move(medium_str), opcode::binary)); auto client3 = std::make_shared(); async_simple::coro::syncAwait( @@ -1047,7 +1050,7 @@ TEST_CASE("test websocket binary data") { std::string long_str(65536, 'A'); async_simple::coro::syncAwait( - client3->write_websocket(std::move(long_str), true, opcode::binary)); + client3->write_websocket(std::move(long_str), opcode::binary)); async_simple::coro::syncAwait(client1->write_websocket_close()); async_simple::coro::syncAwait(client2->write_websocket_close()); @@ -1483,4 +1486,44 @@ TEST_CASE("test reverse proxy") { req_content_type::text); std::cout << resp_random.resp_body << "\n"; CHECK(!resp_random.resp_body.empty()); -} \ No newline at end of file +} + +TEST_CASE("test reverse proxy websocket") { + coro_http_server server(1, 9001); + server.set_http_handler( + "/ws_echo", + [](coro_http_request &req, + coro_http_response &resp) -> async_simple::coro::Lazy { + CHECK(req.get_content_type() == content_type::websocket); + websocket_result result{}; + while (true) { + result = co_await req.get_conn()->read_websocket(); + if (result.ec) { + break; + } + + auto ec = co_await req.get_conn()->write_websocket(result.data); + if (ec) { + break; + } + } + }); + server.async_start(); + + coro_http_server proxy_server(1, 9002); + proxy_server.set_websocket_proxy_handler("/ws_echo", + {"ws://127.0.0.1:9001/ws_echo"}); + proxy_server.async_start(); + std::this_thread::sleep_for(200ms); + + coro_http_client client{}; + auto r = async_simple::coro::syncAwait( + client.connect("ws://127.0.0.1:9002/ws_echo")); + CHECK(!r.net_err); + for (int i = 0; i < 10; i++) { + async_simple::coro::syncAwait(client.write_websocket("test websocket")); + auto data = async_simple::coro::syncAwait(client.read_websocket()); + std::cout << data.resp_body << "\n"; + CHECK(data.resp_body == "test websocket"); + } +}