From 8cc8cab5b58b4a1dcc57799c9388fc2f7fbf953c Mon Sep 17 00:00:00 2001 From: helintong Date: Wed, 17 Apr 2024 17:55:25 +0800 Subject: [PATCH] feat: add websocket compression for coro_client --- include/cinatra/coro_http_client.hpp | 109 ++++++++++++++++++++++- include/cinatra/coro_http_connection.hpp | 4 +- include/cinatra/websocket.hpp | 58 +++--------- 3 files changed, 121 insertions(+), 50 deletions(-) diff --git a/include/cinatra/coro_http_client.hpp b/include/cinatra/coro_http_client.hpp index 567a2c9d..15bf9302 100644 --- a/include/cinatra/coro_http_client.hpp +++ b/include/cinatra/coro_http_client.hpp @@ -21,6 +21,9 @@ #include "async_simple/Unit.h" #include "async_simple/coro/FutureAwaiter.h" #include "async_simple/coro/Lazy.h" +#ifdef CINATRA_ENABLE_GZIP +#include "gzip.hpp" +#endif #include "cinatra_log_wrapper.hpp" #include "http_parser.hpp" #include "multipart.hpp" @@ -319,7 +322,7 @@ class coro_http_client : public std::enable_shared_from_this { void set_ws_sec_key(std::string sec_key) { ws_sec_key_ = std::move(sec_key); } - async_simple::coro::Lazy async_ws_connect(std::string uri) { + async_simple::coro::Lazy async_ws_connect(std::string uri, bool enable_ws_deflate = false) { resp_data data{}; auto [r, u] = handle_uri(data, uri); if (!r) { @@ -327,6 +330,8 @@ class coro_http_client : public std::enable_shared_from_this { co_return false; } + enable_ws_deflate_ = enable_ws_deflate; + req_context<> ctx{}; if (u.is_websocket()) { // build websocket http header @@ -337,10 +342,33 @@ class coro_http_client : public std::enable_shared_from_this { } add_header("Sec-WebSocket-Key", ws_sec_key_); add_header("Sec-WebSocket-Version", "13"); +#ifdef CINATRA_ENABLE_GZIP + add_header("Sec-WebSocket-Extensions", "permessage-deflate; client_max_window_bits"); +#endif } data = co_await async_request(std::move(uri), http_method::GET, std::move(ctx)); +#ifdef CINATRA_ENABLE_GZIP + if (enable_ws_deflate_) { + for (auto c : data.resp_headers) { + std::cout << c.name << " value is: " << c.value << std::endl; + if (c.name == "Sec-WebSocket-Extensions") { + std::cout << "have extensions\n"; + if (c.value.find("permessage-deflate;") != std::string::npos) { + std::cout << "support deflate extensions\n"; + is_server_support_ws_deflate_ = true; + } + else { + std::cout << "not support deflate extensions\n"; + is_server_support_ws_deflate_ = false; + } + break; + } + } + } +#endif + async_read_ws().start([](auto &&) { }); co_return !data.net_err; @@ -376,6 +404,30 @@ class coro_http_client : public std::enable_shared_from_this { } if constexpr (is_span_v) { +#ifdef CINATRA_ENABLE_GZIP + if (enable_ws_deflate_ && is_server_support_ws_deflate_) { + std::string dest_buf; + if (cinatra::gzip_codec::deflate(std::string(source.begin(), source.end()), dest_buf)) { + std::span msg(dest_buf.data(), dest_buf.size()); + auto header = ws.encode_frame(msg, op, need_mask, true, true); + std::vector buffers; + buffers.push_back(asio::buffer(header)); + buffers.push_back(asio::buffer(dest_buf)); + + auto [ec, sz] = co_await async_write(buffers); + if (ec) { + data.net_err = ec; + data.status = 404; + } + } + else { + CINATRA_LOG_ERROR << "compuress data error, data: " << std::string(source.begin(), source.end()); + data.net_err = std::make_error_code(std::errc::protocol_error); + data.status = 404; + } + } + else { +#endif std::string encode_header = ws.encode_frame(source, op, need_mask); std::vector buffers{ asio::buffer(encode_header.data(), encode_header.size()), @@ -386,11 +438,40 @@ class coro_http_client : public std::enable_shared_from_this { data.net_err = ec; data.status = 404; } +#ifdef CINATRA_ENABLE_GZIP + } +#endif } else { while (true) { auto result = co_await source(); +#ifdef CINATRA_ENABLE_GZIP + if (enable_ws_deflate_ && is_server_support_ws_deflate_) { + std::string dest_buf; + if (cinatra::gzip_codec::deflate(std::string(result), dest_buf)) { + std::span msg(dest_buf.data(), dest_buf.size()); + std::string header = + ws.encode_frame(msg, op, need_mask, result.eof, true); + std::vector buffers; + buffers.push_back(asio::buffer(header)); + buffers.push_back(asio::buffer(dest_buf)); + + auto [ec, sz] = co_await async_write(buffers); + if (ec) { + data.net_err = ec; + data.status = 404; + } + } + else { + CINATRA_LOG_ERROR << "compuress data error, data: " << std::string(source.begin(), source.end()); + data.net_err = std::make_error_code(std::errc::protocol_error); + data.status = 404; + } + } + else { +#endif + std::span msg(result.buf.data(), result.buf.size()); std::string encode_header = ws.encode_frame(msg, op, need_mask, result.eof); @@ -409,6 +490,9 @@ class coro_http_client : public std::enable_shared_from_this { break; } } +#ifdef CINATRA_ENABLE_GZIP + } +#endif } co_return data; @@ -1848,9 +1932,27 @@ class coro_http_client : public std::enable_shared_from_this { data_ptr += sizeof(uint16_t); } } +#ifdef CINATRA_ENABLE_GZIP + if (!is_close_frame && is_server_support_ws_deflate_ && enable_ws_deflate_) { + std::string out; + if (!cinatra::gzip_codec::inflate(std::string(data_ptr), out)) + { + CINATRA_LOG_ERROR << "uncompuress data error"; + data.status = 404; + data.net_err = std::make_error_code(std::errc::protocol_error); + break; + } + data.status = 200; + data.resp_body = {out.data(), out.size()}; + } + else { +#endif data.status = 200; data.resp_body = {data_ptr, payload_len}; +#ifdef CINATRA_ENABLE_GZIP + } +#endif read_buf.consume(read_buf.size()); header_size = 2; @@ -2042,6 +2144,11 @@ class coro_http_client : public std::enable_shared_from_this { std::string resp_chunk_str_; std::span out_buf_; + bool enable_ws_deflate_ = false; +#ifdef CINATRA_ENABLE_GZIP + bool is_server_support_ws_deflate_ = false; +#endif + #ifdef BENCHMARK_TEST std::string req_str_; bool stop_bench_ = false; diff --git a/include/cinatra/coro_http_connection.hpp b/include/cinatra/coro_http_connection.hpp index 1d3f07a9..8ec4494c 100644 --- a/include/cinatra/coro_http_connection.hpp +++ b/include/cinatra/coro_http_connection.hpp @@ -564,13 +564,13 @@ class coro_http_connection std::string_view msg, opcode op = opcode::text) { #ifdef CINATRA_ENABLE_GZIP std::string dest_buf; - if (is_client_ws_compressed_ && data_length > 0) { + if (is_client_ws_compressed_ && msg.size() > 0) { if (!cinatra::gzip_codec::deflate(std::string(msg), dest_buf)) { CINATRA_LOG_ERROR << "compuress data error, data: " << msg; co_return std::make_error_code(std::errc::protocol_error); } - auto header = ws_.format_compressed_header(dest_buf.length(), op); + auto header = ws_.format_header(dest_buf.length(), op, true); std::vector buffers; buffers.push_back(asio::buffer(header)); buffers.push_back(asio::buffer(dest_buf)); diff --git a/include/cinatra/websocket.hpp b/include/cinatra/websocket.hpp index 1b33bd43..8aaee243 100644 --- a/include/cinatra/websocket.hpp +++ b/include/cinatra/websocket.hpp @@ -121,61 +121,22 @@ class websocket { return ws_frame_type::WS_BINARY_FRAME; } - std::string format_header(size_t length, opcode code) { - size_t header_length = encode_header(length, code); + std::string format_header(size_t length, opcode code, bool is_compressed = false) { + size_t header_length = encode_header(length, code, is_compressed); return {msg_header_, header_length}; } - std::string format_compressed_header(size_t data_length, opcode code) { - - std::string destbuf; - char first_two_bytes[2] = { 0 }; - //FIN - first_two_bytes[0] |= 0x80; - - first_two_bytes[0] |= code; - - const char compress_flag = 0x40; - first_two_bytes[0] |= compress_flag; - - //mask = 0; - std::string send_data; - - if (data_length < 126) - { - first_two_bytes[1] = data_length; - send_data.append(first_two_bytes, 2); - } - else if (data_length <= UINT16_MAX) - { - first_two_bytes[1] = 126; - char extended_playload_length[2] = { 0 }; - uint16_t tmp = htons(data_length); - memcpy(&extended_playload_length, &tmp, 2); - send_data.append(first_two_bytes, 2); - send_data.append(extended_playload_length, 2); - } - else - { - first_two_bytes[1] = 127; - char extended_playload_length[8] = {0}; - uint64_t tmp = htobe64((uint64_t)data_length); - memcpy(&extended_playload_length, &tmp, 8); - send_data.append(first_two_bytes, 2); - send_data.append(extended_playload_length, 8); - } - - return send_data; - } - std::string encode_frame(std::span &data, opcode op, bool need_mask, - bool eof = true) { + bool eof = true, bool need_compression = false) { std::string header; /// Base header. frame_header hdr{}; hdr.fin = eof; hdr.rsv1 = 0; - hdr.rsv2 = 0; + if (need_compression) + hdr.rsv2 = 1; + else + hdr.rsv2 = 0; hdr.rsv3 = 0; hdr.opcode = static_cast(op); hdr.mask = 1; @@ -269,7 +230,7 @@ class websocket { opcode get_opcode() { return (opcode)msg_opcode_; } private: - size_t encode_header(size_t length, opcode code) { + size_t encode_header(size_t length, opcode code, bool is_compressed = false) { size_t header_length; if (length < 126) { @@ -293,6 +254,9 @@ class websocket { msg_header_[0] |= code; } + if (is_compressed) + msg_header_[0] |= 0x40; + return header_length; }